Added nli, sentiment_classification, dialog_intent and dialog_modeling pipelines, plus some simple abstractions for NLP sequence classification. Removed zero_shot_classification. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9159089master
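For context, a minimal sketch of how the new tasks would be invoked end to end. This is a sketch only: the `pipeline` factory signature, the input formats, and the model IDs below are assumptions for illustration, not part of this diff.

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Hypothetical model IDs; real IDs come from the model hub.
    nli_pipe = pipeline(Tasks.nli, model='damo/nlp_structbert_nli')
    print(nli_pipe(('A soccer game with multiple males playing.',
                    'Some men are playing a sport.')))

    intent_pipe = pipeline(Tasks.dialog_intent_prediction,
                           model='damo/nlp_space_dialog-intent')
    print(intent_pipe('How do I activate my new card?'))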
@@ -16,6 +16,7 @@ class Models(object):
     palm = 'palm-v2'
     structbert = 'structbert'
     veco = 'veco'
+    space = 'space'

     # audio models
     sambert_hifi_16k = 'sambert-hifi-16k'
@@ -52,7 +53,11 @@ class Pipelines(object):
     word_segmentation = 'word-segmentation'
     text_generation = 'text-generation'
     sentiment_analysis = 'sentiment-analysis'
+    sentiment_classification = 'sentiment-classification'
     fill_mask = 'fill-mask'
+    nli = 'nli'
+    dialog_intent_prediction = 'dialog-intent-prediction'
+    dialog_modeling = 'dialog-modeling'
     zero_shot_classification = 'zero-shot-classification'

     # audio tasks
@@ -97,6 +102,11 @@ class Preprocessors(object):
     # nlp preprocessor
     bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
     palm_text_gen_tokenizer = 'palm-text-gen-tokenizer'
+    token_cls_tokenizer = 'token-cls-tokenizer'
+    nli_tokenizer = 'nli-tokenizer'
+    sen_cls_tokenizer = 'sen-cls-tokenizer'
+    dialog_intent_preprocessor = 'dialog-intent-preprocessor'
+    dialog_modeling_preprocessor = 'dialog-modeling-preprocessor'
     sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'
     zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer'
@@ -15,9 +15,13 @@ except ModuleNotFoundError as e:
 try:
     from .audio.kws import GenericKeyWordSpotting
     from .multi_modal import OfaForImageCaptioning
-    from .nlp import (BertForSequenceClassification,
-                      SbertForSentenceSimilarity,
-                      SbertForZeroShotClassification)
+    from .nlp import (BertForMaskedLM, BertForSequenceClassification,
+                      SbertForNLI, SbertForSentenceSimilarity,
+                      SbertForSentimentClassification,
+                      SbertForTokenClassification,
+                      SbertForZeroShotClassification, SpaceForDialogIntent,
+                      SpaceForDialogModeling, StructBertForMaskedLM,
+                      VecoForMaskedLM)
     from .audio.ans.frcrn import FRCRNModel
 except ModuleNotFoundError as e:
     if str(e) == "No module named 'pytorch'":
@@ -1,6 +1,10 @@
 from .bert_for_sequence_classification import *  # noqa F403
 from .masked_language_model import *  # noqa F403
 from .palm_for_text_generation import *  # noqa F403
+from .sbert_for_nli import *  # noqa F403
 from .sbert_for_sentence_similarity import *  # noqa F403
+from .sbert_for_sentiment_classification import *  # noqa F403
 from .sbert_for_token_classification import *  # noqa F403
 from .sbert_for_zero_shot_classification import *  # noqa F403
+from .space.dialog_intent_prediction_model import *  # noqa F403
+from .space.dialog_modeling_model import *  # noqa F403
@@ -4,8 +4,8 @@ from typing import Any, Dict
 import json
 import numpy as np

-from modelscope.metainfo import Models
-from modelscope.utils.constant import Tasks
+from ...metainfo import Models
+from ...utils.constant import Tasks
 from ..base import Model
 from ..builder import MODELS

@@ -16,16 +16,22 @@ class MaskedLanguageModelBase(Model):
         super().__init__(model_dir, *args, **kwargs)
         self.model = self.build_model()

-    def build_model():
+    def build_model(self):
         raise NotImplementedError()

+    def train(self):
+        return self.model.train()
+
+    def eval(self):
+        return self.model.eval()
+
     @property
     def config(self):
         if hasattr(self.model, 'config'):
             return self.model.config
         return None

-    def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, np.ndarray]:
+    def forward(self, input: Dict[str, Tensor]) -> Dict[str, np.ndarray]:
         """return the result by the model

         Args:
@@ -35,10 +41,10 @@ class MaskedLanguageModelBase(Model):
             Dict[str, np.ndarray]: results
         """
         rst = self.model(
-            input_ids=inputs['input_ids'],
-            attention_mask=inputs['attention_mask'],
-            token_type_ids=inputs['token_type_ids'])
-        return {'logits': rst['logits'], 'input_ids': inputs['input_ids']}
+            input_ids=input['input_ids'],
+            attention_mask=input['attention_mask'],
+            token_type_ids=input['token_type_ids'])
+        return {'logits': rst['logits'], 'input_ids': input['input_ids']}


 @MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert)
@@ -1,7 +1,7 @@
 from typing import Dict

-from modelscope.metainfo import Models
-from modelscope.utils.constant import Tasks
+from ...metainfo import Models
+from ...utils.constant import Tasks
 from ..base import Model, Tensor
 from ..builder import MODELS

@@ -20,13 +20,18 @@ class PalmForTextGeneration(Model):
             default loader to load model weights, by default None.
         """
         super().__init__(model_dir, *args, **kwargs)
-        self.model_dir = model_dir
         from sofa.models.palm_v2 import PalmForConditionalGeneration, Translator
         model = PalmForConditionalGeneration.from_pretrained(model_dir)
         self.tokenizer = model.tokenizer
         self.generator = Translator(model)

+    def train(self):
+        return self.generator.train()
+
+    def eval(self):
+        return self.generator.eval()
+
     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         """return the result by the model
@@ -0,0 +1,23 @@
+from ...metainfo import Models
+from ...utils.constant import Tasks
+from ..builder import MODELS
+from .sbert_for_sequence_classification import \
+    SbertForSequenceClassificationBase
+
+__all__ = ['SbertForNLI']
+
+
+@MODELS.register_module(Tasks.nli, module_name=Models.structbert)
+class SbertForNLI(SbertForSequenceClassificationBase):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """initialize the text generation model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+            model_cls (Optional[Any], optional): model loader, if None, use the
+                default loader to load model weights, by default None.
+        """
+        super().__init__(
+            model_dir, *args, model_args={'num_labels': 3}, **kwargs)
+        assert self.model.config.num_labels == 3
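A note on the num_labels=3 assert: the shared base class (sbert_for_sequence_classification.py later in this diff) reads label_mapping.json from the model directory when present, so the three NLI classes are data-driven rather than hard-coded. A sketch of what such a mapping might contain; the actual label names depend on the checkpoint and are an assumption here:

    # Hypothetical label_mapping.json content for a 3-way NLI head.
    label_mapping = {'contradiction': 0, 'neutral': 1, 'entailment': 2}
    id2label = {idx: name for name, idx in label_mapping.items()}
    assert len(label_mapping) == 3  # consistent with the assert above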
@@ -1,46 +1,15 @@
-import os
-from typing import Any, Dict
-
-import json
-import numpy as np
-import torch
-from sofa import SbertModel
-from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel
-from torch import nn
-
-from modelscope.metainfo import Models
-from modelscope.utils.constant import Tasks
-from ..base import Model, Tensor
+from ...metainfo import Models
+from ...utils.constant import Tasks
 from ..builder import MODELS
+from .sbert_for_sequence_classification import \
+    SbertForSequenceClassificationBase

 __all__ = ['SbertForSentenceSimilarity']


-class SbertTextClassifier(SbertPreTrainedModel):
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.num_labels = config.num_labels
-        self.config = config
-        self.encoder = SbertModel(config, add_pooling_layer=True)
-        self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-
-    def forward(self, input_ids=None, token_type_ids=None):
-        outputs = self.encoder(
-            input_ids,
-            token_type_ids=token_type_ids,
-            return_dict=None,
-        )
-        pooled_output = outputs[1]
-        pooled_output = self.dropout(pooled_output)
-        logits = self.classifier(pooled_output)
-        return logits
-
-
 @MODELS.register_module(
     Tasks.sentence_similarity, module_name=Models.structbert)
-class SbertForSentenceSimilarity(Model):
+class SbertForSentenceSimilarity(SbertForSequenceClassificationBase):

     def __init__(self, model_dir: str, *args, **kwargs):
         """initialize the sentence similarity model from the `model_dir` path.
@@ -50,39 +19,7 @@ class SbertForSentenceSimilarity(Model):
             model_cls (Optional[Any], optional): model loader, if None, use the
                 default loader to load model weights, by default None.
         """
-        super().__init__(model_dir, *args, **kwargs)
+        super().__init__(
+            model_dir, *args, model_args={'num_labels': 2}, **kwargs)
         self.model_dir = model_dir
-        self.model = SbertTextClassifier.from_pretrained(
-            model_dir, num_labels=2)
-        self.model.eval()
-        self.label_path = os.path.join(self.model_dir, 'label_mapping.json')
-        with open(self.label_path) as f:
-            self.label_mapping = json.load(f)
-        self.id2label = {idx: name for name, idx in self.label_mapping.items()}
-
-    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
-        """return the result by the model
-
-        Args:
-            input (Dict[str, Any]): the preprocessed data
-
-        Returns:
-            Dict[str, np.ndarray]: results
-                Example:
-                    {
-                        'predictions': array([1]), # lable 0-negative 1-positive
-                        'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
-                        'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
-                    }
-        """
-        input_ids = torch.tensor(input['input_ids'], dtype=torch.long)
-        token_type_ids = torch.tensor(
-            input['token_type_ids'], dtype=torch.long)
-        with torch.no_grad():
-            logits = self.model(input_ids, token_type_ids)
-            probs = logits.softmax(-1).numpy()
-            pred = logits.argmax(-1).numpy()
-            logits = logits.numpy()
-        res = {'predictions': pred, 'probabilities': probs, 'logits': logits}
-        return res
+        assert self.model.config.num_labels == 2
@@ -0,0 +1,22 @@
+from ...metainfo import Models
+from ...utils.constant import Tasks
+from ..builder import MODELS
+from .sbert_for_sequence_classification import \
+    SbertForSequenceClassificationBase
+
+__all__ = ['SbertForSentimentClassification']
+
+
+@MODELS.register_module(
+    Tasks.sentiment_classification, module_name=Models.structbert)
+class SbertForSentimentClassification(SbertForSequenceClassificationBase):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """initialize the text generation model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+        """
+        super().__init__(
+            model_dir, *args, model_args={'num_labels': 2}, **kwargs)
+        assert self.model.config.num_labels == 2
@@ -0,0 +1,71 @@
+import os
+from typing import Any, Dict
+
+import json
+import numpy as np
+import torch
+from sofa.models.sbert.modeling_sbert import SbertModel, SbertPreTrainedModel
+from torch import nn
+
+from ..base import Model
+
+
+class SbertTextClassfier(SbertPreTrainedModel):
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+        self.encoder = SbertModel(config, add_pooling_layer=True)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, input_ids=None, token_type_ids=None):
+        outputs = self.encoder(
+            input_ids,
+            token_type_ids=token_type_ids,
+            return_dict=None,
+        )
+        pooled_output = outputs[1]
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+        return {'logits': logits}
+
+
+class SbertForSequenceClassificationBase(Model):
+
+    def __init__(self, model_dir: str, model_args=None, *args, **kwargs):
+        super().__init__(model_dir, *args, **kwargs)
+        if model_args is None:
+            model_args = {}
+        self.model = SbertTextClassfier.from_pretrained(
+            model_dir, **model_args)
+        self.id2label = {}
+        self.label_path = os.path.join(self.model_dir, 'label_mapping.json')
+        if os.path.exists(self.label_path):
+            with open(self.label_path) as f:
+                self.label_mapping = json.load(f)
+            self.id2label = {
+                idx: name
+                for name, idx in self.label_mapping.items()
+            }
+
+    def train(self):
+        return self.model.train()
+
+    def eval(self):
+        return self.model.eval()
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
+        input_ids = torch.tensor(input['input_ids'], dtype=torch.long)
+        token_type_ids = torch.tensor(
+            input['token_type_ids'], dtype=torch.long)
+        return self.model.forward(input_ids, token_type_ids)
+
+    def postprocess(self, input, **kwargs):
+        logits = input['logits']
+        probs = logits.softmax(-1).numpy()
+        pred = logits.argmax(-1).numpy()
+        logits = logits.numpy()
+        res = {'predictions': pred, 'probabilities': probs, 'logits': logits}
+        return res
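The forward/postprocess split above is the contract the new sequence-classification pipelines rely on: forward returns raw torch logits, postprocess converts them into numpy predictions and probabilities. A minimal driver, assuming `model` is any subclass loaded from a hypothetical local checkpoint and the dict mimics the tokenizer output:

    import torch

    model = SbertForNLI('/path/to/model')   # placeholder model directory
    batch = {'input_ids': [[101, 2769, 102]], 'token_type_ids': [[0, 0, 0]]}
    model.eval()
    with torch.no_grad():
        out = model.forward(batch)    # {'logits': FloatTensor [1, num_labels]}
    res = model.postprocess(out)      # numpy 'predictions'/'probabilities'/'logits'
    print(res['predictions'], res['probabilities'])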
@@ -2,18 +2,17 @@ from typing import Any, Dict, Union

 import numpy as np
 import torch
-from sofa import SbertConfig, SbertForTokenClassification

-from modelscope.metainfo import Models
-from modelscope.utils.constant import Tasks
+from ...metainfo import Models
+from ...utils.constant import Tasks
 from ..base import Model, Tensor
 from ..builder import MODELS

-__all__ = ['StructBertForTokenClassification']
+__all__ = ['SbertForTokenClassification']


 @MODELS.register_module(Tasks.word_segmentation, module_name=Models.structbert)
-class StructBertForTokenClassification(Model):
+class SbertForTokenClassification(Model):

     def __init__(self, model_dir: str, *args, **kwargs):
         """initialize the word segmentation model from the `model_dir` path.
@@ -25,9 +24,16 @@ class StructBertForTokenClassification(Model):
         """
         super().__init__(model_dir, *args, **kwargs)
         self.model_dir = model_dir
-        self.model = SbertForTokenClassification.from_pretrained(
+        import sofa
+        self.model = sofa.SbertForTokenClassification.from_pretrained(
             self.model_dir)
-        self.config = SbertConfig.from_pretrained(self.model_dir)
+        self.config = sofa.SbertConfig.from_pretrained(self.model_dir)
+
+    def train(self):
+        return self.model.train()
+
+    def eval(self):
+        return self.model.eval()

     def forward(self, input: Dict[str,
                                   Any]) -> Dict[str, Union[str, np.ndarray]]:
@@ -46,10 +52,12 @@ class StructBertForTokenClassification(Model):
                     }
         """
         input_ids = torch.tensor(input['input_ids']).unsqueeze(0)
-        output = self.model(input_ids)
-        logits = output.logits
+        return {**self.model(input_ids), 'text': input['text']}
+
+    def postprocess(self, input: Dict[str, Tensor],
+                    **kwargs) -> Dict[str, Tensor]:
+        logits = input['logits']
         pred = torch.argmax(logits[0], dim=-1)
         pred = pred.numpy()
         rst = {'predictions': pred, 'logits': logits, 'text': input['text']}
         return rst
@@ -0,0 +1,81 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+from typing import Any, Dict
+
+from ....metainfo import Models
+from ....preprocessors.space.fields.intent_field import IntentBPETextField
+from ....trainers.nlp.space.trainer.intent_trainer import IntentTrainer
+from ....utils.config import Config
+from ....utils.constant import ModelFile, Tasks
+from ...base import Model, Tensor
+from ...builder import MODELS
+from .model.generator import Generator
+from .model.model_base import SpaceModelBase
+
+__all__ = ['SpaceForDialogIntent']
+
+
+@MODELS.register_module(
+    Tasks.dialog_intent_prediction, module_name=Models.space)
+class SpaceForDialogIntent(Model):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """initialize the test generation model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+        """
+        super().__init__(model_dir, *args, **kwargs)
+        self.model_dir = model_dir
+        self.config = kwargs.pop(
+            'config',
+            Config.from_file(
+                os.path.join(self.model_dir, ModelFile.CONFIGURATION)))
+        self.text_field = kwargs.pop(
+            'text_field',
+            IntentBPETextField(self.model_dir, config=self.config))
+        self.generator = Generator.create(self.config, reader=self.text_field)
+        self.model = SpaceModelBase.create(
+            model_dir=model_dir,
+            config=self.config,
+            reader=self.text_field,
+            generator=self.generator)
+
+        def to_tensor(array):
+            """
+            numpy array -> tensor
+            """
+            import torch
+            array = torch.tensor(array)
+            return array.cuda() if self.config.use_gpu else array
+
+        self.trainer = IntentTrainer(
+            model=self.model,
+            to_tensor=to_tensor,
+            config=self.config,
+            reader=self.text_field)
+        self.trainer.load()
+
+    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        """return the result by the model
+
+        Args:
+            input (Dict[str, Any]): the preprocessed data
+
+        Returns:
+            Dict[str, np.ndarray]: results
+                Example:
+                    {
+                        'predictions': array([1]), # lable 0-negative 1-positive
+                        'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
+                        'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
+                    }
+        """
+        import numpy as np
+        pred = self.trainer.forward(input)
+        pred = np.squeeze(pred[0], 0)
+        return {'pred': pred}
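Note that the docstring example above is carried over from the classification models; the actual return value is {'pred': np.ndarray}, where np.squeeze(pred[0], 0) suggests per-intent scores for a single utterance. A usage sketch under that assumption (model and preprocessed_input are placeholders):

    import numpy as np

    out = model.forward(preprocessed_input)   # model: SpaceForDialogIntent
    print('predicted intent id:', int(np.argmax(out['pred'])))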
@@ -0,0 +1,82 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+from typing import Any, Dict, Optional
+
+from ....metainfo import Models
+from ....preprocessors.space.fields.gen_field import MultiWOZBPETextField
+from ....trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer
+from ....utils.config import Config
+from ....utils.constant import ModelFile, Tasks
+from ...base import Model, Tensor
+from ...builder import MODELS
+from .model.generator import Generator
+from .model.model_base import SpaceModelBase
+
+__all__ = ['SpaceForDialogModeling']
+
+
+@MODELS.register_module(Tasks.dialog_modeling, module_name=Models.space)
+class SpaceForDialogModeling(Model):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """initialize the test generation model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+        """
+        super().__init__(model_dir, *args, **kwargs)
+        self.model_dir = model_dir
+        self.config = kwargs.pop(
+            'config',
+            Config.from_file(
+                os.path.join(self.model_dir, ModelFile.CONFIGURATION)))
+        self.text_field = kwargs.pop(
+            'text_field',
+            MultiWOZBPETextField(self.model_dir, config=self.config))
+        self.generator = Generator.create(self.config, reader=self.text_field)
+        self.model = SpaceModelBase.create(
+            model_dir=model_dir,
+            config=self.config,
+            reader=self.text_field,
+            generator=self.generator)
+
+        def to_tensor(array):
+            """
+            numpy array -> tensor
+            """
+            import torch
+            array = torch.tensor(array)
+            return array.cuda() if self.config.use_gpu else array
+
+        self.trainer = MultiWOZTrainer(
+            model=self.model,
+            to_tensor=to_tensor,
+            config=self.config,
+            reader=self.text_field,
+            evaluator=None)
+        self.trainer.load()
+
+    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        """return the result by the model
+
+        Args:
+            input (Dict[str, Any]): the preprocessed data
+
+        Returns:
+            Dict[str, np.ndarray]: results
+                Example:
+                    {
+                        'predictions': array([1]), # lable 0-negative 1-positive
+                        'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
+                        'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
+                    }
+        """
+        turn = {'user': input['user']}
+        old_pv_turn = input['history']
+
+        pv_turn = self.trainer.forward(turn=turn, old_pv_turn=old_pv_turn)
+
+        return pv_turn
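As with the intent model, the docstring is inherited; forward here actually returns the trainer's pv_turn state. The caller is expected to thread that state back in as 'history' on the next turn, roughly as below (a sketch; the keys inside pv_turn are defined by MultiWOZTrainer, and model is a placeholder instance):

    history = {}
    for user_utterance in ['i need a cheap hotel', 'book it for two nights']:
        pv_turn = model.forward({'user': user_utterance, 'history': history})
        history = pv_turn   # this turn's state seeds the next turn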
@@ -0,0 +1,3 @@
+from .gen_unified_transformer import GenUnifiedTransformer
+from .intent_unified_transformer import IntentUnifiedTransformer
+from .unified_transformer import UnifiedTransformer
@@ -0,0 +1,283 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import torch
+
+from .unified_transformer import UnifiedTransformer
+
+
+class GenUnifiedTransformer(UnifiedTransformer):
+    """
+    Implement generation unified transformer.
+    """
+
+    def __init__(self, model_dir, config, reader, generator):
+        super(GenUnifiedTransformer, self).__init__(model_dir, config, reader,
+                                                    generator)
+        self.understand = config.BPETextField.understand
+
+        if self.use_gpu:
+            self.cuda()
+        return
+
+    def _forward(self, inputs, is_training, with_label):
+        """ Real forward process of model in different mode(train/test). """
+
+        def cat(x, y, dim=1):
+            return torch.cat([x, y], dim=dim)
+
+        outputs = {}
+
+        if self.understand or self.policy:
+            if self.understand:
+                prompt_token = inputs['understand_token']
+                prompt_mask = inputs['understand_mask']
+                if self.policy:
+                    prompt_token = cat(prompt_token, inputs['policy_token'])
+                    prompt_mask = cat(prompt_mask, inputs['policy_mask'])
+            else:
+                prompt_token = inputs['policy_token']
+                prompt_mask = inputs['policy_mask']
+            enc_embed, dec_embed, prompt_embed = self._encoder_prompt_decoder_network(
+                src_token=inputs['src_token'],
+                src_mask=inputs['src_mask'],
+                tgt_token=inputs['tgt_token'][:, :-1],
+                tgt_mask=inputs['tgt_mask'][:, :-1],
+                prompt_token=prompt_token,
+                prompt_mask=prompt_mask,
+                src_pos=inputs['src_pos'],
+                src_type=inputs['src_type'],
+                src_turn=inputs['src_turn'],
+                tgt_pos=inputs['tgt_pos'][:, :-1],
+                tgt_type=inputs['tgt_type'][:, :-1],
+                tgt_turn=inputs['tgt_turn'][:, :-1])
+        else:
+            enc_embed, dec_embed = self._encoder_decoder_network(
+                src_token=inputs['src_token'],
+                src_mask=inputs['src_mask'],
+                tgt_token=inputs['tgt_token'][:, :-1],
+                tgt_mask=inputs['tgt_mask'][:, :-1],
+                src_pos=inputs['src_pos'],
+                src_type=inputs['src_type'],
+                src_turn=inputs['src_turn'],
+                tgt_pos=inputs['tgt_pos'][:, :-1],
+                tgt_type=inputs['tgt_type'][:, :-1],
+                tgt_turn=inputs['tgt_turn'][:, :-1])
+
+        outputs['dec_probs'] = self._dec_head(dec_embed=dec_embed)
+        return outputs
+
+    def _collect_metrics(self, inputs, outputs, with_label, data_file):
+        metrics = {}
+        loss = 0.
+
+        label = inputs['tgt_token'][:, 1:]
+        token_num = torch.sum(torch.sum(inputs['tgt_mask'], dim=1) - 1)
+        nll = self.nll_loss(
+            torch.log(outputs['dec_probs'] + 1e-12).permute(0, 2, 1), label)
+        nll = torch.sum(nll, dim=1)
+        token_nll = torch.sum(nll) / token_num
+        nll = torch.mean(nll)
+        metrics['nll'] = nll
+        metrics['token_nll'] = token_nll
+        metrics['token_num'] = token_num
+        loss = loss + (token_nll if self.token_loss else nll)
+
+        metrics['loss'] = loss
+        if self.gpu > 1:
+            return nll, token_nll, token_num
+        else:
+            return metrics
+
+    def _optimize(self, loss, do_update=False, optimizer=None):
+        """ Optimize loss function and update model. """
+        assert optimizer is not None
+
+        if self.gradient_accumulation_steps > 1:
+            loss = loss / self.gradient_accumulation_steps
+        loss.backward()
+
+        if self.grad_clip is not None and self.grad_clip > 0:
+            torch.nn.utils.clip_grad_norm_(
+                parameters=self.parameters(), max_norm=self.grad_clip)
+
+        if do_update:
+            optimizer.step()
+            optimizer.zero_grad()
+        return
+
+    def _init_state(self,
+                    src_token,
+                    src_mask,
+                    src_pos=None,
+                    src_type=None,
+                    src_turn=None):
+        """ Initialize decode state. """
+        state = {}
+        batch_size = src_token.shape[0]
+
+        src_embed = self.embedder(src_token, src_pos, src_type, src_turn)
+        src_embed = self.embed_layer_norm(src_embed)
+
+        mask = self._create_mask(src_mask, append_head=False)
+
+        enc_out = src_embed
+        cache = {}
+        for _l, layer in enumerate(self.layers):
+            cache[f'layer_{_l}'] = {}
+            enc_out = layer(enc_out, mask, cache[f'layer_{_l}'])
+
+        state['cache'] = cache
+        state['mask'] = mask[:, :1]
+        state['batch_size'] = batch_size
+        shape = [batch_size, 1, 1]
+        state['pred_mask'] = torch.ones(shape, dtype=torch.float32)
+        state['pred_pos'] = torch.zeros(shape, dtype=torch.int64)
+        state['pred_type'] = torch.zeros(shape, dtype=torch.int64)
+        state['pred_turn'] = torch.zeros(shape, dtype=torch.int64)
+        if self.use_gpu:
+            state['pred_mask'] = state['pred_mask'].cuda()
+            state['pred_pos'] = state['pred_pos'].cuda()
+            state['pred_type'] = state['pred_type'].cuda()
+            state['pred_turn'] = state['pred_turn'].cuda()
+
+        return state
+
+    def _init_prompt_state(self,
+                           src_token,
+                           src_mask,
+                           prompt_token,
+                           prompt_mask,
+                           src_pos=None,
+                           src_type=None,
+                           src_turn=None,
+                           prompt_pos=None,
+                           prompt_type=None,
+                           prompt_turn=None):
+        """ Initialize decode state. """
+        state = {}
+        batch_size = src_token.shape[0]
+
+        src_embed = self.embedder(src_token, src_pos, src_type, src_turn)
+        prompt_embed = self.embedder(prompt_token, prompt_pos, prompt_type,
+                                     prompt_turn)
+        embed = torch.cat([src_embed, prompt_embed], dim=1)
+        embed = self.embed_layer_norm(embed)
+
+        enc_out = embed
+
+        enc_mask = self._create_mask(src_mask, auto_regressive=False)
+        dec_mask = self._create_mask(prompt_mask, auto_regressive=True)
+        mask = self._join_mask(enc_mask, dec_mask)
+
+        cache = {}
+        for _l, layer in enumerate(self.layers):
+            cache[f'layer_{_l}'] = {}
+            enc_out = layer(enc_out, mask, cache[f'layer_{_l}'])
+
+        state['cache'] = cache
+        state['mask'] = mask[:, -1:]  # state["mask"] = mask[:, :1]
+        state['batch_size'] = batch_size
+        shape = [batch_size, 1, 1]
+        state['pred_mask'] = torch.ones(shape, dtype=torch.float32)
+        state['pred_pos'] = torch.zeros(shape, dtype=torch.int64)
+        state['pred_type'] = torch.zeros(shape, dtype=torch.int64)
+        state['pred_turn'] = torch.zeros(shape, dtype=torch.int64)
+        if self.use_gpu:
+            state['pred_mask'] = state['pred_mask'].cuda()
+            state['pred_pos'] = state['pred_pos'].cuda()
+            state['pred_type'] = state['pred_type'].cuda()
+            state['pred_turn'] = state['pred_turn'].cuda()
+
+        return state
+
+    def _decode(self, state):
+        """ Decoding one time stamp. """
+        # shape: [batch_size, 1, seq_len]
+        mask = state['mask']
+
+        # shape: [batch_size, 1, 1]
+        pred_token = state['pred_token']
+        pred_mask = state['pred_mask']
+        pred_pos = state['pred_pos']
+        pred_type = state['pred_type']
+        pred_turn = state['pred_turn']
+
+        # list of shape(len: num_layers): [batch_size, seq_len, hidden_dim]
+        cache = state['cache']
+
+        pred_embed = self.embedder(pred_token, pred_pos, pred_type,
+                                   pred_turn).squeeze(-2)
+        pred_embed = self.embed_layer_norm(pred_embed)
+
+        # shape: [batch_size, 1, seq_len + 1]
+        mask = torch.cat([mask, 1 - pred_mask], dim=2)
+
+        # shape: [batch_size, 1, hidden_dim]
+        for _l, layer in enumerate(self.layers):
+            pred_embed = layer(pred_embed, mask, cache[f'layer_{_l}'])
+
+        # shape: [batch_size, vocab_size]
+        pred_probs = self._dec_head(dec_embed=pred_embed[:, 0])
+        pred_logits = torch.log(pred_probs)
+
+        state['mask'] = mask
+        return pred_logits, state
+
+    def _infer(self,
+               inputs,
+               start_id=None,
+               eos_id=None,
+               max_gen_len=None,
+               prev_input=None):
+        """ Real inference process of model. """
+
+        def cat(x, y, dim=1):
+            return torch.cat([x, y], dim=dim)
+
+        # Initial decode state.
+        if self.understand or self.policy:
+            if self.understand:
+                prompt_token = inputs['understand_token']
+                prompt_mask = inputs['understand_mask']
+                if self.policy:
+                    prompt_token = cat(prompt_token, inputs['policy_token'])
+                    prompt_mask = cat(prompt_mask, inputs['policy_mask'])
+            else:
+                prompt_token = inputs['policy_token']
+                prompt_mask = inputs['policy_mask']
+            state = self._init_prompt_state(
+                src_token=inputs['src_token'],
+                src_mask=inputs['src_mask'],
+                prompt_token=prompt_token,
+                prompt_mask=prompt_mask,
+                src_pos=inputs['src_pos'],
+                src_type=inputs['src_type'],
+                src_turn=inputs['src_turn'])
+        else:
+            state = self._init_state(
+                src_token=inputs['src_token'],
+                src_mask=inputs['src_mask'],
+                src_pos=inputs['src_pos'],
+                src_type=inputs['src_type'],
+                src_turn=inputs['src_turn'])
+
+        # Generation process.
+        gen_results = self.generator(
+            step_fn=self._decode,
+            state=state,
+            start_id=start_id,
+            eos_id=eos_id,
+            max_gen_len=max_gen_len,
+            prev_input=prev_input)
+
+        outputs = gen_results['preds']
+        return outputs
+
+
+GenUnifiedTransformer.register('GenUnifiedTransformer')
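A toy check of the loss bookkeeping in _collect_metrics above: nll sums per-token NLL over each target sequence and averages over the batch, while token_nll divides by the total target-token count (the tgt_mask sums minus the start token). Sketch with padding_idx assumed to be 0:

    import torch

    dec_probs = torch.full((2, 3, 4), 0.25)           # uniform over vocab of 4
    label = torch.tensor([[1, 2, 3], [1, 2, 0]])      # 0 = padding
    tgt_mask = torch.tensor([[1., 1., 1., 1.], [1., 1., 1., 0.]])
    nll_loss = torch.nn.NLLLoss(ignore_index=0, reduction='none')

    token_num = torch.sum(torch.sum(tgt_mask, dim=1) - 1)        # 3 + 2 = 5
    nll = nll_loss(torch.log(dec_probs + 1e-12).permute(0, 2, 1), label)
    nll = torch.sum(nll, dim=1)               # per-sequence sums: ~4.16, ~2.77
    print(torch.sum(nll) / token_num)         # token_nll ~= 1.386 = -log(0.25)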
@@ -0,0 +1,287 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import math
+
+import numpy as np
+import torch
+
+
+def repeat(var, times):
+    if isinstance(var, list):
+        return [repeat(x, times) for x in var]
+    elif isinstance(var, dict):
+        return {k: repeat(v, times) for k, v in var.items()}
+    elif isinstance(var, torch.Tensor):
+        var = var.unsqueeze(1)
+        expand_times = [1] * len(var.shape)
+        expand_times[1] = times
+        dtype = var.dtype
+        var = var.float()
+        var = var.repeat(*expand_times)
+        shape = [var.shape[0] * var.shape[1]] + list(var.shape[2:])
+        var = var.reshape(*shape)
+        var = torch.tensor(var, dtype=dtype)
+        return var
+    else:
+        return var
+
+
+def gather(var, idx):
+    if isinstance(var, list):
+        return [gather(x, idx) for x in var]
+    elif isinstance(var, dict):
+        return {k: gather(v, idx) for k, v in var.items()}
+    elif isinstance(var, torch.Tensor):
+        out = var.index_select(dim=0, index=idx)
+        return out
+    else:
+        return var
+
+
+class Generator(object):
+    """ Genrator class. """
+
+    _registry = dict()
+
+    @classmethod
+    def register(cls, name):
+        Generator._registry[name] = cls
+        return
+
+    @staticmethod
+    def by_name(name):
+        return Generator._registry[name]
+
+    @staticmethod
+    def create(config, *args, **kwargs):
+        """ Create generator. """
+        generator_cls = Generator.by_name(config.Generator.generator)
+        return generator_cls(config, *args, **kwargs)
+
+    def __init__(self, config, reader):
+        self.vocab_size = reader.vocab_size
+        self.bos_id = reader.bos_id
+        self.eos_id = reader.eos_id
+        self.unk_id = reader.unk_id
+        self.pad_id = reader.pad_id
+        self.min_gen_len = config.Generator.min_gen_len
+        self.max_gen_len = config.Generator.max_gen_len
+        self.use_gpu = config.use_gpu
+        assert 1 <= self.min_gen_len <= self.max_gen_len
+        return
+
+    def __call__(self, step_fn, state):
+        """
+        Running generation.
+
+        @param : step_fn : decoding one step
+        @type : function
+
+        @param : state : initial state
+        @type : dict
+        """
+        raise NotImplementedError
+
+
+class BeamSearch(Generator):
+    """ BeamSearch generator. """
+
+    def __init__(self, config, reader):
+        super().__init__(config, reader)
+        self.beam_size = config.Generator.beam_size
+        self.length_average = config.Generator.length_average
+        self.length_penalty = config.Generator.length_penalty
+        self.ignore_unk = config.Generator.ignore_unk
+        return
+
+    def __call__(self,
+                 step_fn,
+                 state,
+                 start_id=None,
+                 eos_id=None,
+                 max_gen_len=None,
+                 prev_input=None):
+        """
+        Running beam search.
+
+        @param : step_fn : decoding one step
+        @type : function
+
+        @param : state : initial state
+        @type : dict
+        """
+        if prev_input is not None:
+            if isinstance(prev_input, list):
+                length = max(list(map(lambda x: len(x), prev_input)))
+                prev_input_numpy = np.full((len(prev_input), length),
+                                           self.pad_id)
+                for i, x in enumerate(prev_input):
+                    prev_input_numpy[i, :len(x)] = x
+                prev_input_tensor = torch.from_numpy(prev_input_numpy)
+                if self.use_gpu:
+                    prev_input_tensor = prev_input_tensor.cuda()
+
+                for i in range(length):
+                    state['pred_token'] = prev_input_tensor[:, i].unsqueeze(
+                        -1).unsqueeze(-1)
+                    if i != 0:
+                        state['pred_mask'] = torch.not_equal(
+                            state['pred_token'], self.pad_id).float()
+                        state['pred_pos'] = state['pred_pos'] + state[
+                            'pred_mask'].int()
+                    _, state = step_fn(state)
+            else:
+                assert isinstance(prev_input, torch.Tensor)
+                for i, input in enumerate(prev_input):
+                    state['pred_token'] = input.expand(1, 1, 1)
+                    if i != 0:
+                        state['pred_mask'] = torch.not_equal(
+                            state['pred_token'], self.pad_id).float()
+                        state['pred_pos'] = state['pred_pos'] + 1
+                    _, state = step_fn(state)
+
+        batch_size = state['batch_size']
+        beam_size = self.beam_size
+
+        # shape: [batch_size, 1]
+        pos_index = torch.arange(
+            0, batch_size, 1, dtype=torch.int64) * beam_size
+        pos_index = pos_index.unsqueeze(1)
+
+        # shape: [batch_size, beam_size, 1]
+        if start_id is None:
+            start_id = self.bos_id
+        if eos_id is None:
+            eos_id = self.eos_id
+        predictions = torch.ones([batch_size, beam_size, 1],
+                                 dtype=torch.int64) * start_id
+
+        if self.use_gpu:
+            pos_index = pos_index.cuda()
+            predictions = predictions.cuda()
+
+        # initial input (start_id)
+        state['pred_token'] = predictions[:, :1]
+        if prev_input is not None:
+            state['pred_mask'] = torch.not_equal(state['pred_token'],
+                                                 self.pad_id).float()
+            state['pred_pos'] = state['pred_pos'] + 1
+
+        # shape: [batch_size, vocab_size]
+        scores, state = step_fn(state)
+
+        unk_penalty = np.zeros(self.vocab_size, dtype='float32')
+        unk_penalty[self.unk_id] = -1e10
+        unk_penalty = torch.from_numpy(unk_penalty)
+
+        eos_penalty = np.zeros(self.vocab_size, dtype='float32')
+        eos_penalty[eos_id] = -1e10
+        eos_penalty = torch.from_numpy(eos_penalty)
+
+        scores_after_end = np.full(self.vocab_size, -1e10, dtype='float32')
+        scores_after_end[
+            self.
+            pad_id] = 0  # we want <pad> is generated after <eos>,so maximum log(p(<pad>)) is (0)
+        scores_after_end = torch.from_numpy(scores_after_end)
+
+        if self.use_gpu:
+            unk_penalty = unk_penalty.cuda()
+            eos_penalty = eos_penalty.cuda()
+            scores_after_end = scores_after_end.cuda()
+
+        if self.ignore_unk:
+            scores = scores + unk_penalty
+        scores = scores + eos_penalty
+
+        # shape: [batch_size, beam_size]
+        sequence_scores, preds = torch.topk(scores, self.beam_size)
+
+        predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2)
+        state = repeat(state, beam_size)
+
+        if max_gen_len is None:
+            max_gen_len = self.max_gen_len
+        for step in range(2, max_gen_len + 1):
+            pre_ids = predictions[:, :, -1:]
+            state['pred_token'] = pre_ids.reshape(batch_size * beam_size, 1, 1)
+            state['pred_mask'] = torch.not_equal(state['pred_token'],
+                                                 self.pad_id).float()
+            state['pred_pos'] = state['pred_pos'] + 1
+            scores, state = step_fn(state)
+
+            # Generate next
+            # scores shape: [batch_size * beam_size, vocab_size]
+            if self.ignore_unk:
+                scores = scores + unk_penalty
+
+            if step <= self.min_gen_len:
+                scores = scores + eos_penalty
+
+            # scores shape: [batch_size, beam_size, vocab_size]
+            scores = scores.reshape(batch_size, beam_size, self.vocab_size)
+
+            # previous token is [PAD] or [EOS]
+            pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \
+                           (1 - torch.not_equal(pre_ids, self.pad_id).float())
+
+            scores = scores * (1 - pre_eos_mask) + pre_eos_mask.repeat(
+                1, 1, self.vocab_size) * scores_after_end
+            if self.length_average:
+                scaled_value = \
+                    pre_eos_mask + (1 - pre_eos_mask) * (1 - 1 / step)
+                sequence_scores = sequence_scores.unsqueeze(2) * scaled_value
+                scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 / step)
+                scores = scores * scaled_value
+            elif self.length_penalty >= 0.0:
+                scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \
+                    (math.pow((4 + step) / (5 + step), self.length_penalty))
+                sequence_scores = scaled_value * sequence_scores
+                scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \
+                    (math.pow(1 / (5 + step), self.length_penalty))
+                scores = scores * scaled_value
+            scores = scores + sequence_scores.unsqueeze(-1)
+            scores = scores.reshape(batch_size, beam_size * self.vocab_size)
+
+            topk_scores, topk_indices = torch.topk(scores, beam_size)
+            # topk_indices: [batch_size, beam_size * self.vocab_size] (already reshaped)
+            parent_idx = topk_indices.floor_divide(self.vocab_size)
+            preds = topk_indices % self.vocab_size
+
+            # Gather state / sequence_scores
+            parent_idx = parent_idx + pos_index
+            parent_idx = parent_idx.reshape(batch_size * beam_size)
+            state = gather(state, parent_idx)
+            sequence_scores = topk_scores
+
+            predictions = predictions.reshape(batch_size * beam_size, step)
+            predictions = gather(predictions, parent_idx)
+            predictions = predictions.reshape(batch_size, beam_size, step)
+            predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2)
+
+        # The last token should be <eos> or <pad>
+        pre_ids = predictions[:, :, -1]
+        pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \
+                       (1 - torch.not_equal(pre_ids, self.pad_id).float())
+        sequence_scores = sequence_scores * pre_eos_mask + (
+            1 - pre_eos_mask) * (-1e10)
+
+        # first get ascending ordered index,then sort "predictions" and "sequence_scores"
+        indices = torch.argsort(sequence_scores, dim=1)
+        indices = indices + pos_index
+        indices = indices.reshape(-1)
+        sequence_scores = sequence_scores.reshape(batch_size * beam_size)
+        predictions = predictions.reshape(batch_size * beam_size, -1)
+        sequence_scores = gather(sequence_scores, indices)
+        predictions = gather(predictions, indices)
+        sequence_scores = sequence_scores.reshape(batch_size, beam_size)
+        predictions = predictions.reshape(batch_size, beam_size, -1)
+
+        results = {
+            'preds': predictions[:, -1],
+            'scores': sequence_scores[:, -1]
+        }
+        return results
+
+
+BeamSearch.register('BeamSearch')
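The repeat/gather helpers at the top of this file do the beam bookkeeping: repeat tiles each batch row beam_size times (so the beams of one sample stay contiguous), and gather reindexes those flattened rows after topk reshuffles the beams, with pos_index = arange(batch_size) * beam_size converting per-sample beam indices into flat row indices. A small numeric check of that arithmetic:

    import torch

    var = torch.tensor([[10, 11], [20, 21]])                 # [batch=2, ...]
    tiled = var.unsqueeze(1).repeat(1, 3, 1).reshape(6, 2)   # beam_size=3
    parent_idx = torch.tensor([[2, 0, 1], [1, 1, 0]])        # chosen parent beams
    pos_index = (torch.arange(2) * 3).unsqueeze(1)
    flat = (parent_idx + pos_index).reshape(-1)              # tensor([2, 0, 1, 4, 4, 3])
    print(tiled.index_select(0, flat))                       # rows regrouped per beam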
@@ -0,0 +1,197 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .....utils.nlp.space.criterions import compute_kl_loss
+from .unified_transformer import UnifiedTransformer
+
+
+class IntentUnifiedTransformer(UnifiedTransformer):
+    """
+    Implement intent unified transformer.
+    """
+
+    def __init__(self, model_dir, config, reader, generator):
+        super(IntentUnifiedTransformer, self).__init__(model_dir, config,
+                                                       reader, generator)
+        self.example = config.Model.example
+        self.num_intent = config.Model.num_intent
+        self.with_rdrop = config.Model.with_rdrop
+        self.kl_ratio = config.Model.kl_ratio
+        self.loss_fct = nn.CrossEntropyLoss()
+
+        if self.example:
+            self.loss_fct = nn.NLLLoss()
+        else:
+            self.intent_classifier = nn.Linear(self.hidden_dim,
+                                               self.num_intent)
+            self.loss_fct = nn.CrossEntropyLoss()
+
+        if self.use_gpu:
+            self.cuda()
+        return
+
+    def _forward(self, inputs, is_training, with_label):
+        """ Real forward process of model in different mode(train/test). """
+
+        def aug(v):
+            assert isinstance(v, torch.Tensor)
+            return torch.cat([v, v], dim=0)
+
+        outputs = {}
+
+        if self.with_mlm:
+            mlm_embed = self._encoder_network(
+                input_token=inputs['mlm_token'],
+                input_mask=inputs['src_mask'],
+                input_pos=inputs['src_pos'],
+                input_type=inputs['src_type'],
+                input_turn=inputs['src_turn'])
+            outputs['mlm_probs'] = self._mlm_head(mlm_embed=mlm_embed)
+
+        if self.with_rdrop or self.with_contrastive:
+            enc_embed, dec_embed = self._encoder_decoder_network(
+                src_token=aug(inputs['src_token']),
+                src_mask=aug(inputs['src_mask']),
+                tgt_token=aug(inputs['tgt_token']),
+                tgt_mask=aug(inputs['tgt_mask']),
+                src_pos=aug(inputs['src_pos']),
+                src_type=aug(inputs['src_type']),
+                src_turn=aug(inputs['src_turn']))
+        else:
+            enc_embed, dec_embed = self._encoder_decoder_network(
+                src_token=inputs['src_token'],
+                src_mask=inputs['src_mask'],
+                tgt_token=inputs['tgt_token'],
+                tgt_mask=inputs['tgt_mask'],
+                src_pos=inputs['src_pos'],
+                src_type=inputs['src_type'],
+                src_turn=inputs['src_turn'])
+
+        features = dec_embed[:, -1]
+        features = self.pooler(features) if self.with_pool else features
+
+        if self.example:
+            assert not self.with_rdrop
+            ex_enc_embed, ex_dec_embed = self._encoder_decoder_network(
+                src_token=inputs['example_src_token'],
+                src_mask=inputs['example_src_mask'],
+                tgt_token=inputs['example_tgt_token'],
+                tgt_mask=inputs['example_tgt_mask'],
+                src_pos=inputs['example_src_pos'],
+                src_type=inputs['example_src_type'],
+                src_turn=inputs['example_src_turn'])
+            ex_features = ex_dec_embed[:, -1]
+            ex_features = self.pooler(
+                ex_features) if self.with_pool else ex_features
+            probs = self.softmax(features.mm(ex_features.t()))
+            example_intent = inputs['example_intent'].unsqueeze(0)
+            intent_probs = torch.zeros(probs.size(0), self.num_intent)
+            intent_probs = intent_probs.cuda(
+            ) if self.use_gpu else intent_probs
+            intent_probs = intent_probs.scatter_add(
+                -1, example_intent.repeat(probs.size(0), 1), probs)
+            outputs['intent_probs'] = intent_probs
+        else:
+            intent_logits = self.intent_classifier(features)
+            outputs['intent_logits'] = intent_logits
+
+        if self.with_contrastive:
+            features = features if self.with_pool else self.pooler(features)
+            batch_size = features.size(0) // 2
+            features = \
+                torch.cat(
+                    [features[:batch_size].unsqueeze(1), features[batch_size:].unsqueeze(1)],
+                    dim=1
+                )
+            features = F.normalize(features, dim=-1, p=2)
+            outputs['features'] = features
+
+        return outputs
+
+    def _collect_metrics(self, inputs, outputs, with_label, data_file):
+        metrics = {}
+        batch_size = inputs['src_token'].size(0)
+
+        intent_label = torch.cat([inputs['intent_label'], inputs['intent_label']], dim=0) \
+            if self.with_rdrop or self.with_contrastive else inputs['intent_label']
+        if self.example:
+            intent_loss = self.loss_fct(
+                torch.log(outputs['intent_probs'] + 1e-12).view(
+                    -1, self.num_intent), intent_label.type(torch.long))
+        else:
+            intent_loss = self.loss_fct(
+                outputs['intent_logits'].view(-1, self.num_intent),
+                intent_label.type(torch.long))
+        metrics['intent_loss'] = intent_loss
+        loss = intent_loss
+
+        if self.with_mlm:
+            mlm_num = torch.sum(torch.sum(inputs['mlm_mask'], dim=1))
+            mlm = self.nll_loss(
+                torch.log(outputs['mlm_probs'] + 1e-12).permute(0, 2, 1),
+                inputs['mlm_label'])
+            mlm = torch.sum(mlm, dim=1)
+            token_mlm = torch.sum(mlm) / mlm_num
+            mlm = torch.mean(mlm)
+            metrics['mlm'] = mlm
+            metrics['token_mlm'] = token_mlm
+            metrics['mlm_num'] = mlm_num
+            loss = loss + (token_mlm
+                           if self.token_loss else mlm) * self.mlm_ratio
+        else:
+            mlm, token_mlm, mlm_num = None, None, None
+
+        if self.with_rdrop:
+            kl = compute_kl_loss(
+                p=outputs['intent_logits'][:batch_size],
+                q=outputs['intent_logits'][batch_size:])
+            metrics['kl'] = kl
+            loss = loss + kl * self.kl_ratio
+        else:
+            kl = None
+
+        if self.with_contrastive:
+            pass
+            con = None
+        else:
+            con = None
+
+        metrics['loss'] = loss
+        if self.gpu > 1:
+            return intent_loss, mlm, token_mlm, mlm_num, kl, con
+        else:
+            return metrics
+
+    def _infer(self,
+               inputs,
+               start_id=None,
+               eos_id=None,
+               max_gen_len=None,
+               prev_input=None):
+        """ Real inference process of model. """
+        results = {}
+        enc_embed, dec_embed = self._encoder_decoder_network(
+            src_token=inputs['src_token'],
+            src_mask=inputs['src_mask'],
+            tgt_token=inputs['tgt_token'],
+            tgt_mask=inputs['tgt_mask'],
+            src_pos=inputs['src_pos'],
+            src_type=inputs['src_type'],
+            src_turn=inputs['src_turn'])
+        features = dec_embed[:, -1]
+        features = self.pooler(features) if self.with_pool else features
+        if self.example:
+            results['features'] = features
+        else:
+            intent_logits = self.intent_classifier(features)
+            intent_probs = self.softmax(intent_logits)
+            results['intent_probs'] = intent_probs
+
+        return results
+
+
+IntentUnifiedTransformer.register('IntentUnifiedTransformer')
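The aug() helper above implements R-Drop style regularization: the batch is duplicated, both copies pass through the dropout-active network, and compute_kl_loss pulls the two predictive distributions together. A sketch of the usual symmetric formulation; this is an assumption, since the exact definition lives in utils.nlp.space.criterions, which is not part of this diff:

    import torch
    import torch.nn.functional as F

    def symmetric_kl(p_logits, q_logits):
        # Assumed compute_kl_loss: average of KL(p||q) and KL(q||p).
        p_log = F.log_softmax(p_logits, dim=-1)
        q_log = F.log_softmax(q_logits, dim=-1)
        kl_pq = F.kl_div(q_log, p_log.exp(), reduction='batchmean')
        kl_qp = F.kl_div(p_log, q_log.exp(), reduction='batchmean')
        return (kl_pq + kl_qp) / 2

    logits = torch.randn(8, 5)            # two stochastic passes over a batch of 4
    loss_kl = symmetric_kl(logits[:4], logits[4:])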
@@ -0,0 +1,101 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+
+import torch.nn as nn
+
+from .....utils.constant import ModelFile
+
+
+class SpaceModelBase(nn.Module):
+    """
+    Basic model wrapper for static graph and dygrpah.
+    """
+    _registry = dict()
+
+    @classmethod
+    def register(cls, name):
+        SpaceModelBase._registry[name] = cls
+        return
+
+    @staticmethod
+    def by_name(name):
+        return SpaceModelBase._registry[name]
+
+    @staticmethod
+    def create(model_dir, config, *args, **kwargs):
+        model_cls = SpaceModelBase.by_name(config.Model.model)
+        return model_cls(model_dir, config, *args, **kwargs)
+
+    def __init__(self, model_dir, config):
+        super(SpaceModelBase, self).__init__()
+        self.init_checkpoint = os.path.join(model_dir,
+                                            ModelFile.TORCH_MODEL_BIN_FILE)
+        self.abandon_label = config.Dataset.abandon_label
+        self.use_gpu = config.use_gpu
+        self.gpu = config.Trainer.gpu
+        return
+
+    def _create_parameters(self):
+        """ Create model's paramters. """
+        raise NotImplementedError
+
+    def _forward(self, inputs, is_training, with_label):
+        """ NO LABEL: Real forward process of model in different mode(train/test). """
+        raise NotImplementedError
+
+    def _collect_metrics(self, inputs, outputs, with_label, data_file):
+        """ NO LABEL: Calculate loss function by using inputs and outputs. """
+        raise NotImplementedError
+
+    def _optimize(self, loss, optimizer, lr_scheduler):
+        """ Optimize loss function and update model. """
+        raise NotImplementedError
+
+    def _infer(self, inputs, start_id, eos_id, max_gen_len, prev_input):
+        """ Real inference process of model. """
+        raise NotImplementedError
+
+    def forward(self,
+                inputs,
+                is_training=False,
+                with_label=False,
+                data_file=None):
+        """
+        Forward process, include real forward, collect metrices and optimize(optional)
+
+        @params : inputs : input data
+        @type : dict of numpy.ndarray/int/float/...
+        """
+        if is_training:
+            self.train()
+        else:
+            self.eval()
+
+        with_label = False if self.abandon_label else with_label
+        outputs = self._forward(inputs, is_training, with_label=with_label)
+        metrics = self._collect_metrics(
+            inputs, outputs, with_label=with_label, data_file=data_file)
+
+        return metrics
+
+    def infer(self,
+              inputs,
+              start_id=None,
+              eos_id=None,
+              max_gen_len=None,
+              prev_input=None):
+        """
+        Inference process.
+
+        @params : inputs : input data
+        @type : dict of numpy.ndarray/int/float/...
+        """
+        self.eval()
+        results = self._infer(
+            inputs,
+            start_id=start_id,
+            eos_id=eos_id,
+            max_gen_len=max_gen_len,
+            prev_input=prev_input)
+
+        return results
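SpaceModelBase uses the same name-based registry idiom as Generator earlier in this diff: each concrete model calls Cls.register('Cls') at module import time, and create() looks the class up via the config.Model.model string. The pattern in isolation:

    class Base:
        _registry = dict()

        @classmethod
        def register(cls, name):
            Base._registry[name] = cls

        @staticmethod
        def create(name, *args, **kwargs):
            return Base._registry[name](*args, **kwargs)

    class GenModel(Base):
        pass

    GenModel.register('GenModel')      # same idiom as X.register('X') above
    assert isinstance(Base.create('GenModel'), GenModel)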
@@ -0,0 +1,313 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..modules.embedder import Embedder
+from ..modules.transformer_block import TransformerBlock
+from .model_base import SpaceModelBase
+
+
+class UnifiedTransformer(SpaceModelBase):
+    """
+    Implement unified transformer.
+    """
+
+    def __init__(self, model_dir, config, reader, generator, dtype='float32'):
+        super(UnifiedTransformer, self).__init__(model_dir, config)
+        self.reader = reader
+        self.generator = generator
+        self.policy = config.BPETextField.policy
+        self.generation = config.BPETextField.generation
+        self.num_token_embeddings = config.Model.num_token_embeddings
+        self.num_pos_embeddings = config.Model.num_pos_embeddings
+        self.num_type_embeddings = config.Model.num_type_embeddings
+        self.num_turn_embeddings = config.Model.num_turn_embeddings
+        self.temperature = config.Model.temperature
+        self.hidden_dim = config.Model.hidden_dim
+        self.num_heads = config.Model.num_heads
+        self.num_layers = config.Model.num_layers
+        self.padding_idx = config.Model.padding_idx
+        self.dropout = config.Model.dropout
+        self.embed_dropout = config.Model.embed_dropout
+        self.attn_dropout = config.Model.attn_dropout
+        self.ff_dropout = config.Model.ff_dropout
+        self.mlm_ratio = config.Model.mlm_ratio
+        self.mmd_ratio = config.Model.mmd_ratio
+        self.pos_trainable = config.Model.pos_trainable
+        self.label_smooth = config.Model.label_smooth
+        self.initializer_range = config.Model.initializer_range
+        self.gradient_accumulation_steps = config.Model.gradient_accumulation_steps
+        self.token_loss = config.Trainer.token_loss
+        self.learning_method = config.Dataset.learning_method
+        self.with_contrastive = config.Dataset.with_contrastive
+        self.with_query_bow = config.BPETextField.with_query_bow
+        self.with_resp_bow = config.BPETextField.with_resp_bow
+        self.with_pool = config.Model.with_pool
+        self.with_mlm = config.Dataset.with_mlm
+        self._dtype = dtype
+
+        self.embedder = Embedder(
+            self.hidden_dim,
+            self.num_token_embeddings,
+            self.num_pos_embeddings,
+            self.num_type_embeddings,
+            self.num_turn_embeddings,
+            padding_idx=self.padding_idx,
+            dropout=self.embed_dropout,
+            pos_trainable=self.pos_trainable)
+        self.embed_layer_norm = nn.LayerNorm(
+            normalized_shape=self.hidden_dim,
+            eps=1e-12,
+            elementwise_affine=True)
+
+        self.layers = nn.ModuleList([
+            TransformerBlock(self.hidden_dim, self.num_heads, self.dropout,
+                             self.attn_dropout, self.ff_dropout)
+            for _ in range(config.Model.num_layers)
+        ])
+
+        if self.with_mlm:
+            self.mlm_transform = nn.Sequential(
+                nn.Linear(self.hidden_dim, self.hidden_dim), nn.GELU(),
+                nn.LayerNorm(
+                    normalized_shape=self.hidden_dim,
+                    eps=1e-12,
+                    elementwise_affine=True))
+            self.mlm_bias = nn.Parameter(
+                torch.zeros(self.num_token_embeddings))
+
+        self.pooler = nn.Sequential(
+            nn.Linear(self.hidden_dim, self.hidden_dim), nn.Tanh())
+
+        if self.with_query_bow or self.with_resp_bow:
+            self.bow_predictor = nn.Linear(
+                self.hidden_dim, self.num_token_embeddings, bias=False)
+
+        self.sigmoid = nn.Sigmoid()
+        self.softmax = nn.Softmax(dim=-1)
+        self.bce_loss = nn.BCELoss(reduction='none')
+        self.nll_loss = nn.NLLLoss(
+            ignore_index=self.padding_idx, reduction='none')
+        self._create_parameters()
+
+        self.max_grad_norm = config.Model.max_grad_norm
+        if self.max_grad_norm is not None:
+            self.grad_clip = self.max_grad_norm
+        else:
+            self.grad_clip = None
+        self.weight_decay = config.Model.weight_decay
| if self.use_gpu: | |||||
| self.cuda() | |||||
| return | |||||
| def _create_parameters(self): | |||||
| """ Create model's paramters. """ | |||||
| sequence_mask = np.tri( | |||||
| self.num_pos_embeddings, | |||||
| self.num_pos_embeddings, | |||||
| dtype=self._dtype) | |||||
| self.sequence_mask = torch.tensor(sequence_mask) | |||||
| return | |||||
| def _create_mask(self, | |||||
| input_mask, | |||||
| append_head=False, | |||||
| auto_regressive=False): | |||||
| """ | |||||
| Create attention mask. | |||||
| from sequence to matrix: [batch_size, max_seq_len] -> [batch_size, max_seq_len, max_seq_len] | |||||
| @param : input_mask | |||||
| @type : Variable(shape: [batch_size, max_seq_len]) | |||||
| @param : append_head : whether to prepend one extra position to the mask | |||||
| @type : bool | |||||
| @param : auto_regressive | |||||
| @type : bool | |||||
| """ | |||||
| seq_len = input_mask.shape[1] | |||||
| input_mask = input_mask.float() | |||||
| mask1 = input_mask.unsqueeze(-1).repeat(1, 1, seq_len) | |||||
| mask2 = mask1.permute(0, 2, 1) | |||||
| mask = mask1 * mask2 | |||||
| if append_head: | |||||
| mask = torch.cat([mask[:, :1, :], mask], dim=1) | |||||
| mask = torch.cat([mask[:, :, :1], mask], dim=2) | |||||
| seq_len += 1 | |||||
| if auto_regressive: | |||||
| seq_mask = self.sequence_mask[:seq_len, :seq_len] | |||||
| seq_mask = seq_mask.to(mask.device) | |||||
| mask = mask * seq_mask | |||||
| mask = 1 - mask | |||||
| return mask | |||||
| def _join_mask(self, mask1, mask2): | |||||
| """ | |||||
| Merge source attention mask and target attention mask. | |||||
| There are four parts:left upper (lu) / right upper (ru) / left below (lb) / right below (rb) | |||||
| @param : mask1 : source attention mask | |||||
| @type : Variable(shape: [batch_size, max_src_len, max_src_len]) | |||||
| @param : mask2 : target attention mask | |||||
| @type : Variable(shape: [batch_size, max_tgt_len, max_tgt_len]) | |||||
| """ | |||||
| batch_size = mask1.shape[0] | |||||
| seq_len1 = mask1.shape[1] | |||||
| seq_len2 = mask2.shape[1] | |||||
| # seq_len = seq_len1 + seq_len2 | |||||
| mask_lu = mask1 | |||||
| mask_ru = torch.ones(batch_size, seq_len1, seq_len2) | |||||
| if self.use_gpu: | |||||
| mask_ru = mask_ru.cuda() | |||||
| mask3 = mask2[:, :, :1].repeat(1, 1, seq_len1) | |||||
| mask4 = mask1[:, :1].repeat(1, seq_len2, 1) | |||||
| mask_lb = mask3 + mask4 - mask3 * mask4 | |||||
| mask_rb = mask2 | |||||
| mask_u = torch.cat([mask_lu, mask_ru], dim=2) | |||||
| mask_b = torch.cat([mask_lb, mask_rb], dim=2) | |||||
| mask = torch.cat([mask_u, mask_b], dim=1) | |||||
| return mask | |||||
| def _mlm_head(self, mlm_embed): | |||||
| mlm_embed = self.mlm_transform(mlm_embed) | |||||
| mlm_logits = torch.matmul( | |||||
| mlm_embed, self.embedder.token_embedding.weight.T) + self.mlm_bias | |||||
| mlm_probs = self.softmax(mlm_logits) | |||||
| return mlm_probs | |||||
| def _dec_head(self, dec_embed): | |||||
| dec_logits = torch.matmul(dec_embed, | |||||
| self.embedder.token_embedding.weight.T) | |||||
| dec_probs = self.softmax(dec_logits) | |||||
| return dec_probs | |||||
| def _refactor_feature(self, features): | |||||
| features = self.pooler(features) if self.with_pool else features | |||||
| batch_size = features.size(0) // 2 | |||||
| features = \ | |||||
| torch.cat( | |||||
| [features[:batch_size].unsqueeze(1), features[batch_size:].unsqueeze(1)], | |||||
| dim=1 | |||||
| ) | |||||
| features = F.normalize(features, dim=-1, p=2) | |||||
| return features | |||||
| def _encoder_network(self, | |||||
| input_token, | |||||
| input_mask, | |||||
| input_pos=None, | |||||
| input_type=None, | |||||
| input_turn=None): | |||||
| embed = self.embedder(input_token, input_pos, input_type, input_turn) | |||||
| embed = self.embed_layer_norm(embed) | |||||
| mask = self._create_mask(input_mask, auto_regressive=False) | |||||
| for layer in self.layers: | |||||
| embed = layer(embed, mask, None) | |||||
| return embed | |||||
| def _encoder_decoder_network(self, | |||||
| src_token, | |||||
| src_mask, | |||||
| tgt_token, | |||||
| tgt_mask, | |||||
| src_pos=None, | |||||
| src_type=None, | |||||
| src_turn=None, | |||||
| tgt_pos=None, | |||||
| tgt_type=None, | |||||
| tgt_turn=None): | |||||
| src_embed = self.embedder(src_token, src_pos, src_type, src_turn) | |||||
| tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn) | |||||
| embed = torch.cat([src_embed, tgt_embed], dim=1) | |||||
| embed = self.embed_layer_norm(embed) | |||||
| enc_mask = self._create_mask(src_mask, auto_regressive=False) | |||||
| dec_mask = self._create_mask(tgt_mask, auto_regressive=True) | |||||
| mask = self._join_mask(enc_mask, dec_mask) | |||||
| for layer in self.layers: | |||||
| embed = layer(embed, mask, None) | |||||
| tgt_len = tgt_token.shape[1] | |||||
| enc_embed = embed[:, :-tgt_len] | |||||
| dec_embed = embed[:, -tgt_len:] | |||||
| return enc_embed, dec_embed | |||||
| def _encoder_prompt_decoder_network(self, | |||||
| src_token, | |||||
| src_mask, | |||||
| tgt_token, | |||||
| tgt_mask, | |||||
| prompt_token, | |||||
| prompt_mask, | |||||
| src_pos=None, | |||||
| src_type=None, | |||||
| src_turn=None, | |||||
| tgt_pos=None, | |||||
| tgt_type=None, | |||||
| tgt_turn=None, | |||||
| prompt_pos=None, | |||||
| prompt_type=None, | |||||
| prompt_turn=None): | |||||
| src_embed = self.embedder(src_token, src_pos, src_type, src_turn) | |||||
| tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn) | |||||
| prompt_embed = self.embedder(prompt_token, prompt_pos, prompt_type, | |||||
| prompt_turn) | |||||
| embed = torch.cat([src_embed, prompt_embed, tgt_embed], dim=1) | |||||
| embed = self.embed_layer_norm(embed) | |||||
| enc_mask = self._create_mask(src_mask, auto_regressive=False) | |||||
| dec_mask = self._create_mask( | |||||
| torch.cat([prompt_mask, tgt_mask], dim=1), auto_regressive=True) | |||||
| mask = self._join_mask(enc_mask, dec_mask) | |||||
| for layer in self.layers: | |||||
| embed = layer(embed, mask, None) | |||||
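| # split the jointly encoded sequence back into source / prompt / target segments | |||||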
| src_len = src_token.shape[1] | |||||
| tgt_len = tgt_token.shape[1] | |||||
| enc_embed = embed[:, :src_len] | |||||
| dec_embed = embed[:, -tgt_len:] | |||||
| prompt_embed = embed[:, src_len:-tgt_len] | |||||
| return enc_embed, dec_embed, prompt_embed | |||||
| def _optimize(self, loss, optimizer=None, lr_scheduler=None): | |||||
| """ Optimize loss function and update model. """ | |||||
| assert optimizer is not None | |||||
| optimizer.zero_grad() | |||||
| loss.backward() | |||||
| if self.grad_clip is not None and self.grad_clip > 0: | |||||
| torch.nn.utils.clip_grad_norm_( | |||||
| parameters=self.parameters(), max_norm=self.grad_clip) | |||||
| optimizer.step() | |||||
| if lr_scheduler is not None: | |||||
| lr_scheduler.step() | |||||
| return | |||||
| def _infer(self, | |||||
| inputs, | |||||
| start_id=None, | |||||
| eos_id=None, | |||||
| max_gen_len=None, | |||||
| prev_input=None): | |||||
| """ Real inference process of model. """ | |||||
| results = {} | |||||
| return results | |||||
| UnifiedTransformer.register('UnifiedTransformer') | |||||
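A note on the mask convention used by `_create_mask`/`_join_mask` above: the returned matrix holds 1 at positions that must be blocked (the code ends with `mask = 1 - mask`), which the attention module later consumes via `masked_fill_(mask.bool(), -inf)`. A tiny self-contained sketch with illustrative sizes:

    import torch

    pad_mask = torch.tensor([[1., 1., 0.]])          # [batch, seq_len]; 0 marks padding
    m1 = pad_mask.unsqueeze(-1).repeat(1, 1, 3)      # [batch, seq_len, seq_len]
    visible = m1 * m1.permute(0, 2, 1)               # 1 where both positions are real tokens
    mask = 1 - visible                               # 1 = blocked, as in _create_mask
    # the auto-regressive variant also intersects with a lower-triangular matrix
    ar_mask = 1 - visible * torch.tril(torch.ones(3, 3))
    print(mask)
    print(ar_mask)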
| @@ -0,0 +1,65 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| class Embedder(nn.Module): | |||||
| """ | |||||
| Composite embedding layer. | |||||
| """ | |||||
| def __init__(self, | |||||
| hidden_dim, | |||||
| num_token_embeddings, | |||||
| num_pos_embeddings, | |||||
| num_type_embeddings, | |||||
| num_turn_embeddings, | |||||
| padding_idx=None, | |||||
| dropout=0.1, | |||||
| pos_trainable=False): | |||||
| super(Embedder, self).__init__() | |||||
| self.token_embedding = nn.Embedding(num_token_embeddings, hidden_dim) | |||||
| self.pos_embedding = nn.Embedding(num_pos_embeddings, hidden_dim) | |||||
| self.pos_embedding.weight.requires_grad = pos_trainable | |||||
| self.type_embedding = nn.Embedding(num_type_embeddings, hidden_dim) | |||||
| self.turn_embedding = nn.Embedding(num_turn_embeddings, hidden_dim) | |||||
| self.dropout_layer = nn.Dropout(p=dropout) | |||||
| # follow the default xavier_uniform initializer of the paddle version; | |||||
| # otherwise the dec_probs computation breaks under the weight-tying setting, | |||||
| # since pytorch's default normal initializer for nn.Embedding samples larger values | |||||
| nn.init.xavier_uniform_(self.token_embedding.weight) | |||||
| nn.init.xavier_uniform_(self.pos_embedding.weight) | |||||
| nn.init.xavier_uniform_(self.type_embedding.weight) | |||||
| nn.init.xavier_uniform_(self.turn_embedding.weight) | |||||
| return | |||||
| def forward(self, token_inp, pos_inp=None, type_inp=None, turn_inp=None): | |||||
| embed = self.token_embedding(token_inp) | |||||
| if pos_inp is not None: | |||||
| embed += self.pos_embedding(pos_inp) | |||||
| if type_inp is not None: | |||||
| embed += self.type_embedding(type_inp) | |||||
| if turn_inp is not None: | |||||
| embed += self.turn_embedding(turn_inp) | |||||
| embed = self.dropout_layer(embed) | |||||
| return embed | |||||
| def main(): | |||||
| import numpy as np | |||||
| model = Embedder(10, 20, 20, 20, 20) | |||||
| token_inp = torch.tensor( | |||||
| np.random.randint(0, 19, [10, 10]).astype('int64')) | |||||
| pos_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) | |||||
| type_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) | |||||
| turn_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) | |||||
| out = model(token_inp, pos_inp, type_inp, turn_inp) | |||||
| print(out) | |||||
| if __name__ == '__main__': | |||||
| main() | |||||
| @@ -0,0 +1,41 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| class FeedForward(nn.Module): | |||||
| """ | |||||
| Positional feed forward layer. | |||||
| """ | |||||
| def __init__(self, hidden_dim, inner_dim, dropout): | |||||
| super(FeedForward, self).__init__() | |||||
| self.hidden_dim = hidden_dim | |||||
| self.inner_dim = inner_dim | |||||
| self.linear_hidden = nn.Sequential( | |||||
| nn.Linear(hidden_dim, inner_dim), nn.GELU()) | |||||
| self.linear_out = nn.Linear(inner_dim, hidden_dim) | |||||
| self.dropout_layer = nn.Dropout(p=dropout) | |||||
| return | |||||
| def forward(self, x): | |||||
| out = self.linear_hidden(x) | |||||
| out = self.dropout_layer(out) | |||||
| out = self.linear_out(out) | |||||
| return out | |||||
| def main(): | |||||
| import numpy as np | |||||
| model = FeedForward(10, 20, 0.5) | |||||
| inp = np.random.rand(2, 3, 10).astype('float32') | |||||
| inp = torch.tensor(inp) | |||||
| out = model(inp) | |||||
| print(out) | |||||
| if __name__ == '__main__': | |||||
| main() | |||||
| @@ -0,0 +1,62 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| def unsqueeze(input, dims): | |||||
| """ Implement multi-dimension unsqueeze function. """ | |||||
| if isinstance(dims, (list, tuple)): | |||||
| dims = [ | |||||
| dim if dim >= 0 else dim + len(input.shape) + 1 for dim in dims | |||||
| ] | |||||
| dims = sorted(dims, reverse=True) | |||||
| shape = list(input.shape) | |||||
| for dim in dims: | |||||
| shape.insert(dim, 1) | |||||
| return torch.reshape(input, shape) | |||||
| elif isinstance(dims, int): | |||||
| return input.unsqueeze(dims) | |||||
| else: | |||||
| raise ValueError('type(dims) must be one of (list, tuple, int)!') | |||||
| def gumbel_softmax(input, tau=1, eps=1e-10): | |||||
| """ Basic implement of gumbel_softmax. """ | |||||
| U = torch.tensor(np.random.rand(*input.shape)) | |||||
| gumbel = 0.0 - torch.log(eps - torch.log(U + eps)) | |||||
| y = input + gumbel | |||||
| return F.softmax(y / tau, dim=-1) | |||||
| def equal(x, y, dtype=None): | |||||
| """ Implement equal in dygraph mode. (paddle) """ | |||||
| if dtype is None: | |||||
| dtype = 'float32' | |||||
| if isinstance(x, torch.Tensor): | |||||
| x = x.numpy() | |||||
| if isinstance(y, torch.Tensor): | |||||
| y = y.numpy() | |||||
| out = np.equal(x, y).astype(dtype) | |||||
| return torch.tensor(out) | |||||
| def not_equal(x, y, dtype=None): | |||||
| """ Implement not_equal in dygraph mode. (paddle) """ | |||||
| return 1 - equal(x, y, dtype) | |||||
| if __name__ == '__main__': | |||||
| a = torch.tensor([[1, 1], [3, 4]]) | |||||
| b = torch.tensor([[1, 1], [3, 4]]) | |||||
| c = torch.equal(a, a) | |||||
| c1 = equal(a, 3) | |||||
| d = 1 - torch.not_equal(a, 3).float() | |||||
| print(c) | |||||
| print(c1) | |||||
| print(d) | |||||
| e = F.gumbel_softmax(a.float())  # gumbel_softmax expects float logits | |||||
| f = a.unsqueeze(0)  # unsqueeze takes an int dim, not a tensor | |||||
| g = unsqueeze(a, dims=[0, 0, 1]) | |||||
| print(g, g.shape) | |||||
| @@ -0,0 +1,105 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| class MultiheadAttention(nn.Module): | |||||
| """ | |||||
| Multi head attention layer. | |||||
| """ | |||||
| def __init__(self, hidden_dim, num_heads, dropout): | |||||
| assert hidden_dim % num_heads == 0 | |||||
| super(MultiheadAttention, self).__init__() | |||||
| self.hidden_dim = hidden_dim | |||||
| self.num_heads = num_heads | |||||
| self.head_dim = hidden_dim // num_heads | |||||
| self.scale = self.head_dim**-0.5 | |||||
| self.linear_qkv = nn.Linear(hidden_dim, hidden_dim * 3) | |||||
| self.linear_out = nn.Linear(hidden_dim, hidden_dim) | |||||
| self.dropout_layer = nn.Dropout(p=dropout) | |||||
| self.softmax = nn.Softmax(dim=-1) | |||||
| return | |||||
| def _split_heads(self, x, is_key=False): | |||||
| x = x.reshape(x.size(0), x.size(1), self.num_heads, self.head_dim) | |||||
| x = x.permute(0, 2, 3, 1) if is_key else x.permute(0, 2, 1, 3) | |||||
| return x | |||||
| def _merge_heads(self, x): | |||||
| x = x.permute(0, 2, 1, 3) | |||||
| x = x.reshape(x.size(0), x.size(1), self.hidden_dim) | |||||
| return x | |||||
| def _attn(self, query, key, value, mask): | |||||
| # shape: [batch_size, num_head, seq_len, seq_len] | |||||
| scores = torch.matmul(query, key) | |||||
| scores = scores * self.scale | |||||
| if mask is not None: | |||||
| mask = mask.unsqueeze(1) | |||||
| mask = mask.repeat(1, self.num_heads, 1, 1) | |||||
| scores.masked_fill_( | |||||
| mask.bool(), | |||||
| float('-inf')) # scores = (1 - mask) * scores + mask * (-1e10) | |||||
| attn = self.softmax(scores) | |||||
| attn = self.dropout_layer(attn) | |||||
| if mask is not None: | |||||
| ''' | |||||
| mask: [batch size, num_heads, seq_len, seq_len] | |||||
| >>> F.softmax([-1e10, -100, -100]) | |||||
| >>> [0.00, 0.50, 0.50] | |||||
| >>> F.softmax([-1e10, -1e10, -1e10]) | |||||
| >>> [0.33, 0.33, 0.33] | |||||
| ==> [0.00, 0.00, 0.00] | |||||
| ''' | |||||
| attn.masked_fill_(mask.bool(), 0.) # attn = (1 - mask) * attn | |||||
| out = torch.matmul(attn, value) | |||||
| return out | |||||
| def forward(self, inp, mask=None, cache=None): | |||||
| """ Forward process of self attention. """ | |||||
| # shape: [batch_size, seq_len, 3 * hidden_dim] | |||||
| qkv = self.linear_qkv(inp) | |||||
| query, key, value = torch.split(qkv, self.hidden_dim, dim=2) | |||||
| # shape: [batch_size, num_head, seq_len, head_dim] | |||||
| query = self._split_heads(query) | |||||
| # shape: [batch_size, num_head, head_dim, seq_len] | |||||
| key = self._split_heads(key, is_key=True) | |||||
| # shape: [batch_size, num_head, seq_len, head_dim] | |||||
| value = self._split_heads(value) | |||||
| if cache is not None: | |||||
| if 'key' in cache and 'value' in cache: | |||||
| key = torch.cat([cache['key'], key], dim=3) | |||||
| value = torch.cat([cache['value'], value], dim=2) | |||||
| cache['key'] = key | |||||
| cache['value'] = value | |||||
| out = self._attn(query, key, value, mask) | |||||
| out = self._merge_heads(out) | |||||
| out = self.linear_out(out) | |||||
| return out | |||||
| def main(): | |||||
| import numpy as np | |||||
| model = MultiheadAttention(10, 2, 0.5) | |||||
| inp = np.random.rand(2, 3, 10).astype('float32') | |||||
| inp = torch.tensor(inp) | |||||
| mask = (np.random.rand(2, 3, 3) > 0.5).astype('float32') | |||||
| mask = torch.tensor(mask) | |||||
| out = model(inp, mask=mask, cache=None) | |||||
| print(out) | |||||
| if __name__ == '__main__': | |||||
| main() | |||||
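The `cache` dict enables incremental decoding: on each call the new key/value heads are concatenated onto what the cache already holds (keys are stored transposed, hence `dim=3` vs `dim=2`). A usage sketch, assuming the `MultiheadAttention` class above is importable:

    import numpy as np
    import torch

    attn = MultiheadAttention(hidden_dim=10, num_heads=2, dropout=0.0)
    attn.eval()
    cache = {}
    prefix = torch.tensor(np.random.rand(1, 3, 10).astype('float32'))
    step = torch.tensor(np.random.rand(1, 1, 10).astype('float32'))
    _ = attn(prefix, mask=None, cache=cache)   # cache now holds K/V for 3 positions
    out = attn(step, mask=None, cache=cache)   # the new token attends over all 4 positions
    print(out.shape)                           # torch.Size([1, 1, 10])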
| @@ -0,0 +1,70 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from .feedforward import FeedForward | |||||
| from .multihead_attention import MultiheadAttention | |||||
| class TransformerBlock(nn.Module): | |||||
| """ | |||||
| Transformer block module. | |||||
| """ | |||||
| def __init__(self, hidden_dim, num_heads, dropout, attn_dropout, | |||||
| ff_dropout): | |||||
| super(TransformerBlock, self).__init__() | |||||
| self.attn = MultiheadAttention( | |||||
| hidden_dim=hidden_dim, num_heads=num_heads, dropout=attn_dropout) | |||||
| self.attn_norm = nn.LayerNorm( | |||||
| normalized_shape=hidden_dim, eps=1e-12, elementwise_affine=True) | |||||
| self.ff = FeedForward( | |||||
| hidden_dim=hidden_dim, | |||||
| inner_dim=4 * hidden_dim, | |||||
| dropout=ff_dropout) | |||||
| self.ff_norm = nn.LayerNorm( | |||||
| normalized_shape=hidden_dim, eps=1e-12, elementwise_affine=True) | |||||
| self.dropout_layer = nn.Dropout(p=dropout) | |||||
| return | |||||
| def forward(self, inp, mask=None, cache=None): | |||||
| """ | |||||
| Forward process on one transformer layer. | |||||
| @param : inp | |||||
| @type : Variable(shape: [batch_size, seq_len, hidden_size]) | |||||
| @param : mask | |||||
| @param : cache | |||||
| """ | |||||
| attn_out = self.attn(inp, mask, cache) | |||||
| attn_out = self.dropout_layer(attn_out) | |||||
| attn_out = self.attn_norm(attn_out + inp) | |||||
| ff_out = self.ff(attn_out) | |||||
| ff_out = self.dropout_layer(ff_out) | |||||
| ff_out = self.ff_norm(ff_out + attn_out) | |||||
| return ff_out | |||||
| def main(): | |||||
| import numpy as np | |||||
| model = TransformerBlock(10, 2, 0.5, 0.5, 0.5) | |||||
| inp = np.random.rand(2, 3, 10).astype('float32') | |||||
| inp = torch.tensor(inp) | |||||
| mask = (np.random.rand(2, 3, 3) > 0.5).astype('float32') | |||||
| mask = torch.tensor(mask) | |||||
| out = model(inp, mask=mask, cache=None) | |||||
| print(out) | |||||
| if __name__ == '__main__': | |||||
| main() | |||||
| @@ -21,6 +21,12 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| Tasks.sentence_similarity: | Tasks.sentence_similarity: | ||||
| (Pipelines.sentence_similarity, | (Pipelines.sentence_similarity, | ||||
| 'damo/nlp_structbert_sentence-similarity_chinese-base'), | 'damo/nlp_structbert_sentence-similarity_chinese-base'), | ||||
| Tasks.nli: (Pipelines.nli, 'damo/nlp_structbert_nli_chinese-base'), | |||||
| Tasks.sentiment_classification: | |||||
| (Pipelines.sentiment_classification, | |||||
| 'damo/nlp_structbert_sentiment-classification_chinese-base'), | |||||
| Tasks.text_classification: ('bert-sentiment-analysis', | |||||
| 'damo/bert-base-sst2'), | |||||
| Tasks.image_matting: (Pipelines.image_matting, | Tasks.image_matting: (Pipelines.image_matting, | ||||
| 'damo/cv_unet_image-matting'), | 'damo/cv_unet_image-matting'), | ||||
| Tasks.text_classification: (Pipelines.sentiment_analysis, | Tasks.text_classification: (Pipelines.sentiment_analysis, | ||||
| @@ -30,6 +36,11 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| Tasks.zero_shot_classification: | Tasks.zero_shot_classification: | ||||
| (Pipelines.zero_shot_classification, | (Pipelines.zero_shot_classification, | ||||
| 'damo/nlp_structbert_zero-shot-classification_chinese-base'), | 'damo/nlp_structbert_zero-shot-classification_chinese-base'), | ||||
| Tasks.dialog_intent_prediction: | |||||
| (Pipelines.dialog_intent_prediction, | |||||
| 'damo/nlp_space_dialog-intent-prediction'), | |||||
| Tasks.dialog_modeling: (Pipelines.dialog_modeling, | |||||
| 'damo/nlp_space_dialog-modeling'), | |||||
| Tasks.image_captioning: (Pipelines.image_caption, | Tasks.image_captioning: (Pipelines.image_caption, | ||||
| 'damo/ofa_image-caption_coco_large_en'), | 'damo/ofa_image-caption_coco_large_en'), | ||||
| Tasks.image_generation: | Tasks.image_generation: | ||||
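With these entries in place, a pipeline built by task name alone should resolve to the default model id. A hedged sketch, assuming the `pipeline` factory exposed by `modelscope.pipelines` (the builder that reads DEFAULT_MODEL_FOR_PIPELINE):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # no model argument: DEFAULT_MODEL_FOR_PIPELINE supplies
    # 'damo/nlp_structbert_nli_chinese-base' for the nli task
    nli_pipe = pipeline(task=Tasks.nli)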
| @@ -1,6 +1,10 @@ | |||||
| try: | try: | ||||
| from .dialog_intent_prediction_pipeline import * # noqa F403 | |||||
| from .dialog_modeling_pipeline import * # noqa F403 | |||||
| from .fill_mask_pipeline import * # noqa F403 | from .fill_mask_pipeline import * # noqa F403 | ||||
| from .nli_pipeline import * # noqa F403 | |||||
| from .sentence_similarity_pipeline import * # noqa F403 | from .sentence_similarity_pipeline import * # noqa F403 | ||||
| from .sentiment_classification_pipeline import * # noqa F403 | |||||
| from .sequence_classification_pipeline import * # noqa F403 | from .sequence_classification_pipeline import * # noqa F403 | ||||
| from .text_generation_pipeline import * # noqa F403 | from .text_generation_pipeline import * # noqa F403 | ||||
| from .word_segmentation_pipeline import * # noqa F403 | from .word_segmentation_pipeline import * # noqa F403 | ||||
| @@ -0,0 +1,53 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import Any, Dict | |||||
| import numpy as np | |||||
| from ...metainfo import Pipelines | |||||
| from ...models.nlp import SpaceForDialogIntent | |||||
| from ...preprocessors import DialogIntentPredictionPreprocessor | |||||
| from ...utils.constant import Tasks | |||||
| from ..base import Pipeline | |||||
| from ..builder import PIPELINES | |||||
| from ..outputs import OutputKeys | |||||
| __all__ = ['DialogIntentPredictionPipeline'] | |||||
| @PIPELINES.register_module( | |||||
| Tasks.dialog_intent_prediction, | |||||
| module_name=Pipelines.dialog_intent_prediction) | |||||
| class DialogIntentPredictionPipeline(Pipeline): | |||||
| def __init__(self, model: SpaceForDialogIntent, | |||||
| preprocessor: DialogIntentPredictionPreprocessor, **kwargs): | |||||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||||
| Args: | |||||
| model (SequenceClassificationModel): a model instance | |||||
| preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | |||||
| """ | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.model = model | |||||
| self.categories = preprocessor.categories | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | |||||
| """process the prediction results | |||||
| Args: | |||||
| inputs (Dict[str, Any]): the raw prediction results from the model | |||||
| Returns: | |||||
| Dict[str, str]: the prediction results | |||||
| """ | |||||
| pred = inputs['pred'] | |||||
| pos = np.where(pred == np.max(pred)) | |||||
| result = { | |||||
| OutputKeys.PREDICTION: pred, | |||||
| OutputKeys.LABEL_POS: pos[0], | |||||
| OutputKeys.LABEL: self.categories[pos[0][0]] | |||||
| } | |||||
| return result | |||||
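A possible end-to-end invocation of this pipeline; the utterance is hypothetical, and the output keys follow `postprocess` above:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    p = pipeline(task=Tasks.dialog_intent_prediction)
    result = p('I lost my card, can you freeze my account?')
    print(result['label'], result['label_pos'])   # e.g. 'lost_or_stolen_card', array([11])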
| @@ -0,0 +1,49 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import Any, Dict, Optional | |||||
| from ...metainfo import Pipelines | |||||
| from ...models.nlp import SpaceForDialogModeling | |||||
| from ...preprocessors import DialogModelingPreprocessor | |||||
| from ...utils.constant import Tasks | |||||
| from ..base import Pipeline, Tensor | |||||
| from ..builder import PIPELINES | |||||
| from ..outputs import OutputKeys | |||||
| __all__ = ['DialogModelingPipeline'] | |||||
| @PIPELINES.register_module( | |||||
| Tasks.dialog_modeling, module_name=Pipelines.dialog_modeling) | |||||
| class DialogModelingPipeline(Pipeline): | |||||
| def __init__(self, model: SpaceForDialogModeling, | |||||
| preprocessor: DialogModelingPreprocessor, **kwargs): | |||||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||||
| Args: | |||||
| model (SequenceClassificationModel): a model instance | |||||
| preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | |||||
| """ | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.model = model | |||||
| self.preprocessor = preprocessor | |||||
| def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: | |||||
| """process the prediction results | |||||
| Args: | |||||
| inputs (Dict[str, Any]): the raw prediction results from the model | |||||
| Returns: | |||||
| Dict[str, str]: the prediction results | |||||
| """ | |||||
| sys_rsp = self.preprocessor.text_field.tokenizer.convert_ids_to_tokens( | |||||
| inputs['resp']) | |||||
| assert len(sys_rsp) > 2 | |||||
| sys_rsp = sys_rsp[1:-1]  # drop the first and last special tokens | |||||
| inputs[OutputKeys.RESPONSE] = sys_rsp | |||||
| return inputs | |||||
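A sketch of driving this pipeline for one turn; the input contract follows `DialogModelingPreprocessor` below, which reads `data['user_input']` (any additional multi-turn fields the text field may expect are omitted here):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    p = pipeline(task=Tasks.dialog_modeling)
    result = p({'user_input': 'i am looking for a cheap restaurant in the centre'})
    print(result['response'])   # detokenized system response tokens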
| @@ -1,5 +1,7 @@ | |||||
| import os | import os | ||||
| from typing import Dict, Optional, Union | |||||
| from typing import Any, Dict, Optional, Union | |||||
| import torch | |||||
| from ...metainfo import Pipelines | from ...metainfo import Pipelines | ||||
| from ...models import Model | from ...models import Model | ||||
| @@ -21,6 +23,7 @@ class FillMaskPipeline(Pipeline): | |||||
| def __init__(self, | def __init__(self, | ||||
| model: Union[MaskedLanguageModelBase, str], | model: Union[MaskedLanguageModelBase, str], | ||||
| preprocessor: Optional[FillMaskPreprocessor] = None, | preprocessor: Optional[FillMaskPreprocessor] = None, | ||||
| first_sequence='sentence', | |||||
| **kwargs): | **kwargs): | ||||
| """use `model` and `preprocessor` to create a nlp fill mask pipeline for prediction | """use `model` and `preprocessor` to create a nlp fill mask pipeline for prediction | ||||
| @@ -30,12 +33,16 @@ class FillMaskPipeline(Pipeline): | |||||
| """ | """ | ||||
| fill_mask_model = model if isinstance( | fill_mask_model = model if isinstance( | ||||
| model, MaskedLanguageModelBase) else Model.from_pretrained(model) | model, MaskedLanguageModelBase) else Model.from_pretrained(model) | ||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = FillMaskPreprocessor( | preprocessor = FillMaskPreprocessor( | ||||
| fill_mask_model.model_dir, | fill_mask_model.model_dir, | ||||
| first_sequence='sentence', | |||||
| first_sequence=first_sequence, | |||||
| second_sequence=None) | second_sequence=None) | ||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| fill_mask_model.eval() | |||||
| super().__init__( | |||||
| model=fill_mask_model, preprocessor=preprocessor, **kwargs) | |||||
| self.preprocessor = preprocessor | self.preprocessor = preprocessor | ||||
| self.config = Config.from_file( | self.config = Config.from_file( | ||||
| os.path.join(fill_mask_model.model_dir, ModelFile.CONFIGURATION)) | os.path.join(fill_mask_model.model_dir, ModelFile.CONFIGURATION)) | ||||
| @@ -63,6 +70,11 @@ class FillMaskPipeline(Pipeline): | |||||
| } | } | ||||
| } | } | ||||
| def forward(self, inputs: Dict[str, Any], | |||||
| **forward_params) -> Dict[str, Any]: | |||||
| with torch.no_grad(): | |||||
| return super().forward(inputs, **forward_params) | |||||
| def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: | def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: | ||||
| """process the prediction results | """process the prediction results | ||||
| @@ -0,0 +1,73 @@ | |||||
| from typing import Any, Dict, Union | |||||
| import numpy as np | |||||
| import torch | |||||
| from ...metainfo import Pipelines | |||||
| from ...models import Model | |||||
| from ...models.nlp import SbertForNLI | |||||
| from ...preprocessors import NLIPreprocessor | |||||
| from ...utils.constant import Tasks | |||||
| from ..base import Pipeline | |||||
| from ..builder import PIPELINES | |||||
| from ..outputs import OutputKeys | |||||
| __all__ = ['NLIPipeline'] | |||||
| @PIPELINES.register_module(Tasks.nli, module_name=Pipelines.nli) | |||||
| class NLIPipeline(Pipeline): | |||||
| def __init__(self, | |||||
| model: Union[SbertForNLI, str], | |||||
| preprocessor: NLIPreprocessor = None, | |||||
| first_sequence='first_sequence', | |||||
| second_sequence='second_sequence', | |||||
| **kwargs): | |||||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||||
| Args: | |||||
| model (SbertForNLI): a model instance | |||||
| preprocessor (NLIPreprocessor): a preprocessor instance | |||||
| """ | |||||
| assert isinstance(model, str) or isinstance(model, SbertForNLI), \ | |||||
| 'model must be a single str or SbertForNLI' | |||||
| model = model if isinstance( | |||||
| model, SbertForNLI) else Model.from_pretrained(model) | |||||
| if preprocessor is None: | |||||
| preprocessor = NLIPreprocessor( | |||||
| model.model_dir, | |||||
| first_sequence=first_sequence, | |||||
| second_sequence=second_sequence) | |||||
| model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| assert len(model.id2label) > 0 | |||||
| def forward(self, inputs: Dict[str, Any], | |||||
| **forward_params) -> Dict[str, Any]: | |||||
| with torch.no_grad(): | |||||
| return super().forward(inputs, **forward_params) | |||||
| def postprocess(self, | |||||
| inputs: Dict[str, Any], | |||||
| topk: int = 5) -> Dict[str, str]: | |||||
| """process the prediction results | |||||
| Args: | |||||
| inputs (Dict[str, Any]): the raw prediction results from the model | |||||
| Returns: | |||||
| Dict[str, str]: the prediction results | |||||
| """ | |||||
| probs = inputs['probabilities'][0] | |||||
| num_classes = probs.shape[0] | |||||
| topk = min(topk, num_classes) | |||||
| top_indices = np.argpartition(probs, -topk)[-topk:] | |||||
| cls_ids = top_indices[np.argsort(probs[top_indices])] | |||||
| probs = probs[cls_ids].tolist() | |||||
| cls_names = [self.model.id2label[cid] for cid in cls_ids] | |||||
| return {OutputKeys.SCORES: probs, OutputKeys.LABELS: cls_names} | |||||
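Usage sketch (the sentence pair is illustrative); the preprocessor registered for this task consumes a (premise, hypothesis) tuple:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    p = pipeline(task=Tasks.nli)
    result = p(('A man is playing a guitar.', 'Someone is making music.'))
    print(result['scores'], result['labels'])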
| @@ -1,12 +1,13 @@ | |||||
| from typing import Any, Dict, Union | from typing import Any, Dict, Union | ||||
| import numpy as np | import numpy as np | ||||
| import torch | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.nlp import SbertForSentenceSimilarity | |||||
| from modelscope.preprocessors import SequenceClassificationPreprocessor | |||||
| from modelscope.utils.constant import Tasks | |||||
| from ...metainfo import Pipelines | |||||
| from ...models import Model | from ...models import Model | ||||
| from ...models.nlp import SbertForSentenceSimilarity | |||||
| from ...preprocessors import SequenceClassificationPreprocessor | |||||
| from ...utils.constant import Tasks | |||||
| from ..base import Input, Pipeline | from ..base import Input, Pipeline | ||||
| from ..builder import PIPELINES | from ..builder import PIPELINES | ||||
| from ..outputs import OutputKeys | from ..outputs import OutputKeys | ||||
| @@ -19,8 +20,10 @@ __all__ = ['SentenceSimilarityPipeline'] | |||||
| class SentenceSimilarityPipeline(Pipeline): | class SentenceSimilarityPipeline(Pipeline): | ||||
| def __init__(self, | def __init__(self, | ||||
| model: Union[SbertForSentenceSimilarity, str], | |||||
| model: Union[Model, str], | |||||
| preprocessor: SequenceClassificationPreprocessor = None, | preprocessor: SequenceClassificationPreprocessor = None, | ||||
| first_sequence='first_sequence', | |||||
| second_sequence='second_sequence', | |||||
| **kwargs): | **kwargs): | ||||
| """use `model` and `preprocessor` to create a nlp sentence similarity pipeline for prediction | """use `model` and `preprocessor` to create a nlp sentence similarity pipeline for prediction | ||||
| @@ -36,14 +39,21 @@ class SentenceSimilarityPipeline(Pipeline): | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = SequenceClassificationPreprocessor( | preprocessor = SequenceClassificationPreprocessor( | ||||
| sc_model.model_dir, | sc_model.model_dir, | ||||
| first_sequence='first_sequence', | |||||
| second_sequence='second_sequence') | |||||
| first_sequence=first_sequence, | |||||
| second_sequence=second_sequence) | |||||
| sc_model.eval() | |||||
| super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) | super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) | ||||
| assert hasattr(self.model, 'id2label'), \ | assert hasattr(self.model, 'id2label'), \ | ||||
| 'id2label map should be initialized in init function.' | 'id2label map should be initialized in init function.' | ||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | |||||
| def forward(self, inputs: Dict[str, Any], | |||||
| **forward_params) -> Dict[str, Any]: | |||||
| with torch.no_grad(): | |||||
| return super().forward(inputs, **forward_params) | |||||
| def postprocess(self, inputs: Dict[str, Any], | |||||
| **postprocess_params) -> Dict[str, str]: | |||||
| """process the prediction results | """process the prediction results | ||||
| Args: | Args: | ||||
| @@ -0,0 +1,78 @@ | |||||
| from typing import Any, Dict, Union | |||||
| import numpy as np | |||||
| import torch | |||||
| from ...metainfo import Pipelines | |||||
| from ...models import Model | |||||
| from ...models.nlp import SbertForSentimentClassification | |||||
| from ...preprocessors import SentimentClassificationPreprocessor | |||||
| from ...utils.constant import Tasks | |||||
| from ..base import Input, Pipeline | |||||
| from ..builder import PIPELINES | |||||
| from ..outputs import OutputKeys | |||||
| __all__ = ['SentimentClassificationPipeline'] | |||||
| @PIPELINES.register_module( | |||||
| Tasks.sentiment_classification, | |||||
| module_name=Pipelines.sentiment_classification) | |||||
| class SentimentClassificationPipeline(Pipeline): | |||||
| def __init__(self, | |||||
| model: Union[SbertForSentimentClassification, str], | |||||
| preprocessor: SentimentClassificationPreprocessor = None, | |||||
| first_sequence='first_sequence', | |||||
| second_sequence='second_sequence', | |||||
| **kwargs): | |||||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||||
| Args: | |||||
| model (SbertForSentimentClassification): a model instance | |||||
| preprocessor (SentimentClassificationPreprocessor): a preprocessor instance | |||||
| """ | |||||
| assert isinstance(model, str) or isinstance(model, SbertForSentimentClassification), \ | |||||
| 'model must be a single str or SbertForSentimentClassification' | |||||
| model = model if isinstance( | |||||
| model, | |||||
| SbertForSentimentClassification) else Model.from_pretrained(model) | |||||
| if preprocessor is None: | |||||
| preprocessor = SentimentClassificationPreprocessor( | |||||
| model.model_dir, | |||||
| first_sequence=first_sequence, | |||||
| second_sequence=second_sequence) | |||||
| model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| assert len(model.id2label) > 0 | |||||
| def forward(self, inputs: Dict[str, Any], | |||||
| **forward_params) -> Dict[str, Any]: | |||||
| with torch.no_grad(): | |||||
| return super().forward(inputs, **forward_params) | |||||
| def postprocess(self, | |||||
| inputs: Dict[str, Any], | |||||
| topk: int = 5) -> Dict[str, str]: | |||||
| """process the prediction results | |||||
| Args: | |||||
| inputs (Dict[str, Any]): the raw prediction results from the model | |||||
| Returns: | |||||
| Dict[str, str]: the prediction results | |||||
| """ | |||||
| probs = inputs['probabilities'][0] | |||||
| num_classes = probs.shape[0] | |||||
| topk = min(topk, num_classes) | |||||
| top_indices = np.argpartition(probs, -topk)[-topk:] | |||||
| cls_ids = top_indices[np.argsort(probs[top_indices])] | |||||
| probs = probs[cls_ids].tolist() | |||||
| cls_names = [self.model.id2label[cid] for cid in cls_ids] | |||||
| return {OutputKeys.SCORES: probs, OutputKeys.LABELS: cls_names} | |||||
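Usage sketch with a hypothetical input; note the default model registered above is a Chinese StructBERT, so a Chinese sentence is the realistic case:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    p = pipeline(task=Tasks.sentiment_classification)
    result = p('这家餐厅的菜很好吃')   # hypothetical input sentence
    print(result['scores'], result['labels'])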
| @@ -1,10 +1,12 @@ | |||||
| from typing import Dict, Optional, Union | |||||
| from typing import Any, Dict, Optional, Union | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models import Model | |||||
| from modelscope.models.nlp import PalmForTextGeneration | |||||
| from modelscope.preprocessors import TextGenerationPreprocessor | |||||
| from modelscope.utils.constant import Tasks | |||||
| import torch | |||||
| from ...metainfo import Pipelines | |||||
| from ...models import Model | |||||
| from ...models.nlp import PalmForTextGeneration | |||||
| from ...preprocessors import TextGenerationPreprocessor | |||||
| from ...utils.constant import Tasks | |||||
| from ..base import Pipeline, Tensor | from ..base import Pipeline, Tensor | ||||
| from ..builder import PIPELINES | from ..builder import PIPELINES | ||||
| from ..outputs import OutputKeys | from ..outputs import OutputKeys | ||||
| @@ -34,10 +36,17 @@ class TextGenerationPipeline(Pipeline): | |||||
| model.tokenizer, | model.tokenizer, | ||||
| first_sequence='sentence', | first_sequence='sentence', | ||||
| second_sequence=None) | second_sequence=None) | ||||
| model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | super().__init__(model=model, preprocessor=preprocessor, **kwargs) | ||||
| self.tokenizer = model.tokenizer | self.tokenizer = model.tokenizer | ||||
| def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: | |||||
| def forward(self, inputs: Dict[str, Any], | |||||
| **forward_params) -> Dict[str, Any]: | |||||
| with torch.no_grad(): | |||||
| return super().forward(inputs, **forward_params) | |||||
| def postprocess(self, inputs: Dict[str, Tensor], | |||||
| **postprocess_params) -> Dict[str, str]: | |||||
| """process the prediction results | """process the prediction results | ||||
| Args: | Args: | ||||
| @@ -1,10 +1,12 @@ | |||||
| from typing import Any, Dict, Optional, Union | from typing import Any, Dict, Optional, Union | ||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models import Model | |||||
| from modelscope.models.nlp import StructBertForTokenClassification | |||||
| from modelscope.preprocessors import TokenClassifcationPreprocessor | |||||
| from modelscope.utils.constant import Tasks | |||||
| import torch | |||||
| from ...metainfo import Pipelines | |||||
| from ...models import Model | |||||
| from ...models.nlp import SbertForTokenClassification | |||||
| from ...preprocessors import TokenClassifcationPreprocessor | |||||
| from ...utils.constant import Tasks | |||||
| from ..base import Pipeline, Tensor | from ..base import Pipeline, Tensor | ||||
| from ..builder import PIPELINES | from ..builder import PIPELINES | ||||
| from ..outputs import OutputKeys | from ..outputs import OutputKeys | ||||
| @@ -17,7 +19,7 @@ __all__ = ['WordSegmentationPipeline'] | |||||
| class WordSegmentationPipeline(Pipeline): | class WordSegmentationPipeline(Pipeline): | ||||
| def __init__(self, | def __init__(self, | ||||
| model: Union[StructBertForTokenClassification, str], | |||||
| model: Union[SbertForTokenClassification, str], | |||||
| preprocessor: Optional[TokenClassifcationPreprocessor] = None, | preprocessor: Optional[TokenClassifcationPreprocessor] = None, | ||||
| **kwargs): | **kwargs): | ||||
| """use `model` and `preprocessor` to create a nlp word segmentation pipeline for prediction | """use `model` and `preprocessor` to create a nlp word segmentation pipeline for prediction | ||||
| @@ -28,15 +30,23 @@ class WordSegmentationPipeline(Pipeline): | |||||
| """ | """ | ||||
| model = model if isinstance( | model = model if isinstance( | ||||
| model, | model, | ||||
| StructBertForTokenClassification) else Model.from_pretrained(model) | |||||
| SbertForTokenClassification) else Model.from_pretrained(model) | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = TokenClassifcationPreprocessor(model.model_dir) | preprocessor = TokenClassifcationPreprocessor(model.model_dir) | ||||
| model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | super().__init__(model=model, preprocessor=preprocessor, **kwargs) | ||||
| self.tokenizer = preprocessor.tokenizer | self.tokenizer = preprocessor.tokenizer | ||||
| self.config = model.config | self.config = model.config | ||||
| assert len(self.config.id2label) > 0 | |||||
| self.id2label = self.config.id2label | self.id2label = self.config.id2label | ||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | |||||
| def forward(self, inputs: Dict[str, Any], | |||||
| **forward_params) -> Dict[str, Any]: | |||||
| with torch.no_grad(): | |||||
| return super().forward(inputs, **forward_params) | |||||
| def postprocess(self, inputs: Dict[str, Any], | |||||
| **postprocess_params) -> Dict[str, str]: | |||||
| """process the prediction results | """process the prediction results | ||||
| Args: | Args: | ||||
| @@ -5,7 +5,9 @@ from modelscope.utils.constant import Tasks | |||||
| class OutputKeys(object): | class OutputKeys(object): | ||||
| SCORES = 'scores' | SCORES = 'scores' | ||||
| LABEL = 'label' | |||||
| LABELS = 'labels' | LABELS = 'labels' | ||||
| LABEL_POS = 'label_pos' | |||||
| POSES = 'poses' | POSES = 'poses' | ||||
| CAPTION = 'caption' | CAPTION = 'caption' | ||||
| BOXES = 'boxes' | BOXES = 'boxes' | ||||
| @@ -16,6 +18,8 @@ class OutputKeys(object): | |||||
| OUTPUT_PCM = 'output_pcm' | OUTPUT_PCM = 'output_pcm' | ||||
| IMG_EMBEDDING = 'img_embedding' | IMG_EMBEDDING = 'img_embedding' | ||||
| TEXT_EMBEDDING = 'text_embedding' | TEXT_EMBEDDING = 'text_embedding' | ||||
| RESPONSE = 'response' | |||||
| PREDICTION = 'prediction' | |||||
| TASK_OUTPUTS = { | TASK_OUTPUTS = { | ||||
| @@ -119,6 +123,13 @@ TASK_OUTPUTS = { | |||||
| # } | # } | ||||
| Tasks.sentence_similarity: [OutputKeys.SCORES, OutputKeys.LABELS], | Tasks.sentence_similarity: [OutputKeys.SCORES, OutputKeys.LABELS], | ||||
| # sentiment classification result for single sample | |||||
| # { | |||||
| # "labels": ["happy", "sad", "calm", "angry"], | |||||
| # "scores": [0.9, 0.1, 0.05, 0.05] | |||||
| # } | |||||
| Tasks.sentiment_classification: [OutputKeys.SCORES, OutputKeys.LABELS], | |||||
| # zero-shot classification result for single sample | # zero-shot classification result for single sample | ||||
| # { | # { | ||||
| # "scores": [0.9, 0.1, 0.05, 0.05] | # "scores": [0.9, 0.1, 0.05, 0.05] | ||||
| @@ -126,6 +137,39 @@ TASK_OUTPUTS = { | |||||
| # } | # } | ||||
| Tasks.zero_shot_classification: [OutputKeys.SCORES, OutputKeys.LABELS], | Tasks.zero_shot_classification: [OutputKeys.SCORES, OutputKeys.LABELS], | ||||
| # nli result for single sample | |||||
| # { | |||||
| # "labels": ["happy", "sad", "calm", "angry"], | |||||
| # "scores": [0.9, 0.1, 0.05, 0.05] | |||||
| # } | |||||
| Tasks.nli: [OutputKeys.SCORES, OutputKeys.LABELS], | |||||
| # dialog intent prediction result for single sample, e.g. (pred array truncated): | |||||
| # {'pred': array([2.62e-03, 4.12e-03, ..., 9.87e-01, ...], dtype=float32), | |||||
| #  'label_pos': array([11]), 'label': 'lost_or_stolen_card'} | |||||
| Tasks.dialog_intent_prediction: | |||||
| [OutputKeys.PREDICTION, OutputKeys.LABEL_POS, OutputKeys.LABEL], | |||||
| # sys : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!'] | |||||
| Tasks.dialog_modeling: [OutputKeys.RESPONSE], | |||||
| # ============ audio tasks =================== | # ============ audio tasks =================== | ||||
| # audio processed for single file in PCM format | # audio processed for single file in PCM format | ||||
| @@ -11,6 +11,8 @@ try: | |||||
| from .audio import LinearAECAndFbank | from .audio import LinearAECAndFbank | ||||
| from .multi_modal import * # noqa F403 | from .multi_modal import * # noqa F403 | ||||
| from .nlp import * # noqa F403 | from .nlp import * # noqa F403 | ||||
| from .space.dialog_intent_prediction_preprocessor import * # noqa F403 | |||||
| from .space.dialog_modeling_preprocessor import * # noqa F403 | |||||
| except ModuleNotFoundError as e: | except ModuleNotFoundError as e: | ||||
| if str(e) == "No module named 'tensorflow'": | if str(e) == "No module named 'tensorflow'": | ||||
| pass | pass | ||||
| @@ -5,15 +5,16 @@ from typing import Any, Dict, Union | |||||
| from transformers import AutoTokenizer | from transformers import AutoTokenizer | ||||
| from modelscope.metainfo import Preprocessors | |||||
| from modelscope.utils.constant import Fields, InputFields | |||||
| from modelscope.utils.type_assert import type_assert | |||||
| from ..metainfo import Models, Preprocessors | |||||
| from ..utils.constant import Fields, InputFields | |||||
| from ..utils.type_assert import type_assert | |||||
| from .base import Preprocessor | from .base import Preprocessor | ||||
| from .builder import PREPROCESSORS | from .builder import PREPROCESSORS | ||||
| __all__ = [ | __all__ = [ | ||||
| 'Tokenize', 'SequenceClassificationPreprocessor', | 'Tokenize', 'SequenceClassificationPreprocessor', | ||||
| 'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor', | 'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor', | ||||
| 'NLIPreprocessor', 'SentimentClassificationPreprocessor', | |||||
| 'FillMaskPreprocessor', 'ZeroShotClassificationPreprocessor' | 'FillMaskPreprocessor', 'ZeroShotClassificationPreprocessor' | ||||
| ] | ] | ||||
| @@ -32,6 +33,140 @@ class Tokenize(Preprocessor): | |||||
| return data | return data | ||||
| @PREPROCESSORS.register_module( | |||||
| Fields.nlp, module_name=Preprocessors.nli_tokenizer) | |||||
| class NLIPreprocessor(Preprocessor): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """preprocess the data via the vocab.txt from the `model_dir` path | |||||
| Args: | |||||
| model_dir (str): model path | |||||
| """ | |||||
| super().__init__(*args, **kwargs) | |||||
| from sofa import SbertTokenizer | |||||
| self.model_dir: str = model_dir | |||||
| self.first_sequence: str = kwargs.pop('first_sequence', | |||||
| 'first_sequence') | |||||
| self.second_sequence = kwargs.pop('second_sequence', 'second_sequence') | |||||
| self.sequence_length = kwargs.pop('sequence_length', 128) | |||||
| self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir) | |||||
| @type_assert(object, tuple) | |||||
| def __call__(self, data: tuple) -> Dict[str, Any]: | |||||
| """process the raw input data | |||||
| Args: | |||||
| data (tuple): (sentence1, sentence2) | |||||
| sentence1 (str): a sentence | |||||
| Example: | |||||
| 'you are so handsome.' | |||||
| sentence2 (str): a sentence | |||||
| Example: | |||||
| 'you are so beautiful.' | |||||
| Returns: | |||||
| Dict[str, Any]: the preprocessed data | |||||
| """ | |||||
| sentence1, sentence2 = data | |||||
| new_data = { | |||||
| self.first_sequence: sentence1, | |||||
| self.second_sequence: sentence2 | |||||
| } | |||||
| # preprocess the data for the model input | |||||
| rst = { | |||||
| 'id': [], | |||||
| 'input_ids': [], | |||||
| 'attention_mask': [], | |||||
| 'token_type_ids': [] | |||||
| } | |||||
| max_seq_length = self.sequence_length | |||||
| text_a = new_data[self.first_sequence] | |||||
| text_b = new_data[self.second_sequence] | |||||
| feature = self.tokenizer( | |||||
| text_a, | |||||
| text_b, | |||||
| padding=False, | |||||
| truncation=True, | |||||
| max_length=max_seq_length) | |||||
| rst['id'].append(new_data.get('id', str(uuid.uuid4()))) | |||||
| rst['input_ids'].append(feature['input_ids']) | |||||
| rst['attention_mask'].append(feature['attention_mask']) | |||||
| rst['token_type_ids'].append(feature['token_type_ids']) | |||||
| return rst | |||||
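Called directly, the preprocessor returns lists of un-padded features keyed like `rst` above. A sketch with a hypothetical local model directory (any real run needs a directory containing the SBERT vocab):

    prep = NLIPreprocessor('/path/to/nli_model_dir')   # hypothetical path
    features = prep(('A man is eating.', 'Someone is having a meal.'))
    print(features['input_ids'][0])    # token ids of the joined sentence pair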
| @PREPROCESSORS.register_module( | |||||
| Fields.nlp, module_name=Preprocessors.sen_cls_tokenizer) | |||||
| class SentimentClassificationPreprocessor(Preprocessor): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """preprocess the data via the vocab.txt from the `model_dir` path | |||||
| Args: | |||||
| model_dir (str): model path | |||||
| """ | |||||
| super().__init__(*args, **kwargs) | |||||
| from sofa import SbertTokenizer | |||||
| self.model_dir: str = model_dir | |||||
| self.first_sequence: str = kwargs.pop('first_sequence', | |||||
| 'first_sequence') | |||||
| self.second_sequence = kwargs.pop('second_sequence', 'second_sequence') | |||||
| self.sequence_length = kwargs.pop('sequence_length', 128) | |||||
| self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir) | |||||
| @type_assert(object, str) | |||||
| def __call__(self, data: str) -> Dict[str, Any]: | |||||
| """process the raw input data | |||||
| Args: | |||||
| data (str): a sentence | |||||
| Example: | |||||
| 'you are so handsome.' | |||||
| Returns: | |||||
| Dict[str, Any]: the preprocessed data | |||||
| """ | |||||
| new_data = {self.first_sequence: data} | |||||
| # preprocess the data for the model input | |||||
| rst = { | |||||
| 'id': [], | |||||
| 'input_ids': [], | |||||
| 'attention_mask': [], | |||||
| 'token_type_ids': [] | |||||
| } | |||||
| max_seq_length = self.sequence_length | |||||
| text_a = new_data[self.first_sequence] | |||||
| text_b = new_data.get(self.second_sequence, None) | |||||
| feature = self.tokenizer( | |||||
| text_a, | |||||
| text_b, | |||||
| padding='max_length', | |||||
| truncation=True, | |||||
| max_length=max_seq_length) | |||||
| rst['id'].append(new_data.get('id', str(uuid.uuid4()))) | |||||
| rst['input_ids'].append(feature['input_ids']) | |||||
| rst['attention_mask'].append(feature['attention_mask']) | |||||
| rst['token_type_ids'].append(feature['token_type_ids']) | |||||
| return rst | |||||
| @PREPROCESSORS.register_module( | @PREPROCESSORS.register_module( | ||||
| Fields.nlp, module_name=Preprocessors.bert_seq_cls_tokenizer) | Fields.nlp, module_name=Preprocessors.bert_seq_cls_tokenizer) | ||||
| class SequenceClassificationPreprocessor(Preprocessor): | class SequenceClassificationPreprocessor(Preprocessor): | ||||
| @@ -178,7 +313,6 @@ class TextGenerationPreprocessor(Preprocessor): | |||||
| rst['input_ids'].append(feature['input_ids']) | rst['input_ids'].append(feature['input_ids']) | ||||
| rst['attention_mask'].append(feature['attention_mask']) | rst['attention_mask'].append(feature['attention_mask']) | ||||
| return {k: torch.tensor(v) for k, v in rst.items()} | return {k: torch.tensor(v) for k, v in rst.items()} | ||||
| @@ -241,7 +375,7 @@ class FillMaskPreprocessor(Preprocessor): | |||||
| @PREPROCESSORS.register_module( | @PREPROCESSORS.register_module( | ||||
| Fields.nlp, module_name=Preprocessors.sbert_token_cls_tokenizer) | |||||
| Fields.nlp, module_name=Preprocessors.token_cls_tokenizer) | |||||
| class TokenClassifcationPreprocessor(Preprocessor): | class TokenClassifcationPreprocessor(Preprocessor): | ||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| @@ -269,6 +403,7 @@ class TokenClassifcationPreprocessor(Preprocessor): | |||||
| Returns: | Returns: | ||||
| Dict[str, Any]: the preprocessed data | Dict[str, Any]: the preprocessed data | ||||
| """ | """ | ||||
| # preprocess the data for the model input | # preprocess the data for the model input | ||||
| text = data.replace(' ', '').strip() | text = data.replace(' ', '').strip() | ||||
| @@ -0,0 +1,57 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| from typing import Any, Dict | |||||
| import json | |||||
| from ...metainfo import Preprocessors | |||||
| from ...utils.config import Config | |||||
| from ...utils.constant import Fields, ModelFile | |||||
| from ...utils.type_assert import type_assert | |||||
| from ..base import Preprocessor | |||||
| from ..builder import PREPROCESSORS | |||||
| from .fields.intent_field import IntentBPETextField | |||||
| __all__ = ['DialogIntentPredictionPreprocessor'] | |||||
| @PREPROCESSORS.register_module( | |||||
| Fields.nlp, module_name=Preprocessors.dialog_intent_preprocessor) | |||||
| class DialogIntentPredictionPreprocessor(Preprocessor): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """preprocess the data via the vocab.txt from the `model_dir` path | |||||
| Args: | |||||
| model_dir (str): model path | |||||
| """ | |||||
| super().__init__(*args, **kwargs) | |||||
| self.model_dir: str = model_dir | |||||
| self.config = Config.from_file( | |||||
| os.path.join(self.model_dir, ModelFile.CONFIGURATION)) | |||||
| self.text_field = IntentBPETextField( | |||||
| self.model_dir, config=self.config) | |||||
| self.categories = None | |||||
| with open(os.path.join(self.model_dir, 'categories.json'), 'r') as f: | |||||
| self.categories = json.load(f) | |||||
| assert len(self.categories) == 77 | |||||
| @type_assert(object, str) | |||||
| def __call__(self, data: str) -> Dict[str, Any]: | |||||
| """process the raw input data | |||||
| Args: | |||||
| data (str): a sentence | |||||
| Example: | |||||
| 'you are so handsome.' | |||||
| Returns: | |||||
| Dict[str, Any]: the preprocessed data | |||||
| """ | |||||
| samples = self.text_field.preprocessor([data]) | |||||
| samples, _ = self.text_field.collate_fn_multi_turn(samples) | |||||
| return samples | |||||
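| # Minimal usage sketch (the path is hypothetical; model_dir must hold the | |||||
| # configuration file, vocab and categories.json referenced above): | |||||
| #     preprocessor = DialogIntentPredictionPreprocessor(model_dir='/path/to/space-dialog-intent') | |||||
| #     features = preprocessor('How do I activate my card?') | |||||
| #     # `features` is the collated multi-turn batch consumed by the intent model | |||||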
| @@ -0,0 +1,51 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| from typing import Any, Dict | |||||
| from ...metainfo import Preprocessors | |||||
| from ...utils.config import Config | |||||
| from ...utils.constant import Fields, ModelFile | |||||
| from ...utils.type_assert import type_assert | |||||
| from ..base import Preprocessor | |||||
| from ..builder import PREPROCESSORS | |||||
| from .fields.gen_field import MultiWOZBPETextField | |||||
| __all__ = ['DialogModelingPreprocessor'] | |||||
| @PREPROCESSORS.register_module( | |||||
| Fields.nlp, module_name=Preprocessors.dialog_modeling_preprocessor) | |||||
| class DialogModelingPreprocessor(Preprocessor): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """preprocess the data via the vocab.txt from the `model_dir` path | |||||
| Args: | |||||
| model_dir (str): model path | |||||
| """ | |||||
| super().__init__(*args, **kwargs) | |||||
| self.model_dir: str = model_dir | |||||
| self.config = Config.from_file( | |||||
| os.path.join(self.model_dir, ModelFile.CONFIGURATION)) | |||||
| self.text_field = MultiWOZBPETextField( | |||||
| self.model_dir, config=self.config) | |||||
| @type_assert(object, Dict) | |||||
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||||
| """process the raw input data | |||||
| Args: | |||||
| data (Dict[str, Any]): the dialog state dict; the latest user | |||||
| utterance is read from its 'user_input' field | |||||
| Example: | |||||
| {'user_input': 'i want a cheap restaurant'} | |||||
| Returns: | |||||
| Dict[str, Any]: the preprocessed data | |||||
| """ | |||||
| user_ids = self.text_field.get_ids(data['user_input']) | |||||
| data['user'] = user_ids | |||||
| return data | |||||
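| # Minimal usage sketch (path hypothetical): only 'user_input' is consumed here; | |||||
| # any other keys in the dict are passed through untouched for the pipeline. | |||||
| #     preprocessor = DialogModelingPreprocessor(model_dir='/path/to/space-dialog-modeling') | |||||
| #     data = preprocessor({'user_input': 'i want a cheap restaurant'}) | |||||
| #     # data['user'] == [sos_u_id, <utterance token ids>, eos_u_id] | |||||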
| @@ -0,0 +1,675 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import random | |||||
| from collections import OrderedDict | |||||
| from itertools import chain | |||||
| import numpy as np | |||||
| from ....utils.nlp.space import ontology, utils | |||||
| from ....utils.nlp.space.db_ops import MultiWozDB | |||||
| from ....utils.nlp.space.utils import list2np | |||||
| from ..tokenizer import Tokenizer | |||||
| class BPETextField(object): | |||||
| pad_token = '[PAD]' | |||||
| bos_token = '[BOS]' | |||||
| eos_token = '[EOS]' | |||||
| unk_token = '[UNK]' | |||||
| sos_u_token = '<sos_u>' | |||||
| eos_u_token = '<eos_u>' | |||||
| sos_b_token = '<sos_b>' | |||||
| eos_b_token = '<eos_b>' | |||||
| sos_d_token = '<sos_d>' | |||||
| eos_d_token = '<eos_d>' | |||||
| sos_a_token = '<sos_a>' | |||||
| eos_a_token = '<eos_a>' | |||||
| sos_db_token = '<sos_db>' | |||||
| eos_db_token = '<eos_db>' | |||||
| sos_r_token = '<sos_r>' | |||||
| eos_r_token = '<eos_r>' | |||||
| @property | |||||
| def bot_id(self): | |||||
| return 0 | |||||
| @property | |||||
| def user_id(self): | |||||
| return 1 | |||||
| @property | |||||
| def vocab_size(self): | |||||
| return self.tokenizer.vocab_size | |||||
| @property | |||||
| def num_specials(self): | |||||
| return len(self.tokenizer.special_tokens) | |||||
| @property | |||||
| def pad_id(self): | |||||
| return self.tokenizer.convert_tokens_to_ids([self.pad_token])[0] | |||||
| @property | |||||
| def bos_id(self): | |||||
| return self.tokenizer.convert_tokens_to_ids([self.bos_token])[0] | |||||
| @property | |||||
| def eos_id(self): | |||||
| return self.tokenizer.convert_tokens_to_ids([self.eos_token])[0] | |||||
| @property | |||||
| def unk_id(self): | |||||
| return self.tokenizer.convert_tokens_to_ids([self.unk_token])[0] | |||||
| @property | |||||
| def sos_u_id(self): | |||||
| return self.tokenizer.convert_tokens_to_ids([self.sos_u_token])[0] | |||||
| @property | |||||
| def eos_u_id(self): | |||||
| return self.tokenizer.convert_tokens_to_ids([self.eos_u_token])[0] | |||||
| @property | |||||
| def sos_b_id(self): | |||||
| return self.tokenizer.convert_tokens_to_ids([self.sos_b_token])[0] | |||||
| @property | |||||
| def eos_b_id(self): | |||||
| return self.tokenizer.convert_tokens_to_ids([self.eos_b_token])[0] | |||||
| @property | |||||
| def sos_db_id(self): | |||||
| return self.tokenizer.convert_tokens_to_ids([self.sos_db_token])[0] | |||||
| @property | |||||
| def eos_db_id(self): | |||||
| return self.tokenizer.convert_tokens_to_ids([self.eos_db_token])[0] | |||||
| @property | |||||
| def sos_a_id(self): | |||||
| return self.tokenizer.convert_tokens_to_ids([self.sos_a_token])[0] | |||||
| @property | |||||
| def eos_a_id(self): | |||||
| return self.tokenizer.convert_tokens_to_ids([self.eos_a_token])[0] | |||||
| @property | |||||
| def sos_r_id(self): | |||||
| return self.tokenizer.convert_tokens_to_ids([self.sos_r_token])[0] | |||||
| @property | |||||
| def eos_r_id(self): | |||||
| return self.tokenizer.convert_tokens_to_ids([self.eos_r_token])[0] | |||||
| @property | |||||
| def sos_d_id(self): | |||||
| return self.tokenizer.convert_tokens_to_ids([self.sos_d_token])[0] | |||||
| @property | |||||
| def eos_d_id(self): | |||||
| return self.tokenizer.convert_tokens_to_ids([self.eos_d_token])[0] | |||||
| def __init__(self, config): | |||||
| self.gpu = 0 | |||||
| self.tokenizer = None | |||||
| self.vocab = None | |||||
| self.db = None | |||||
| self.set_stats = {} | |||||
| self.prompt_num_for_understand = config.BPETextField.prompt_num_for_understand | |||||
| self.prompt_num_for_policy = config.BPETextField.prompt_num_for_policy | |||||
| self.understand_tokens = ontology.get_understand_tokens( | |||||
| self.prompt_num_for_understand) | |||||
| self.policy_tokens = ontology.get_policy_tokens( | |||||
| self.prompt_num_for_policy) | |||||
| self.with_query_bow = config.BPETextField.with_query_bow | |||||
| self.understand = config.BPETextField.understand | |||||
| self.policy = config.BPETextField.policy | |||||
| self.batch_size = config.Trainer.batch_size | |||||
| self.filtered = config.BPETextField.filtered | |||||
| self.max_len = config.BPETextField.max_len | |||||
| self.min_utt_len = config.BPETextField.min_utt_len | |||||
| self.max_utt_len = config.BPETextField.max_utt_len | |||||
| self.min_ctx_turn = config.BPETextField.min_ctx_turn | |||||
| self.max_ctx_turn = config.BPETextField.max_ctx_turn - 1 # subtract reply turn | |||||
| self.use_true_prev_bspn = config.Generator.use_true_prev_bspn | |||||
| self.use_true_prev_aspn = config.Generator.use_true_prev_aspn | |||||
| self.use_true_db_pointer = config.Generator.use_true_db_pointer | |||||
| self.use_true_prev_resp = config.Generator.use_true_prev_resp | |||||
| self.use_true_curr_bspn = config.Generator.use_true_curr_bspn | |||||
| self.use_true_curr_aspn = config.Generator.use_true_curr_aspn | |||||
| self.use_all_previous_context = config.Generator.use_all_previous_context | |||||
| self.use_true_bspn_for_ctr_eval = config.Generator.use_true_bspn_for_ctr_eval | |||||
| self.use_true_domain_for_ctr_eval = config.Generator.use_true_domain_for_ctr_eval | |||||
| def collate_fn_multi_turn(self, samples): | |||||
| batch_size = len(samples) | |||||
| batch = {} | |||||
| src = [sp['src'][-self.max_ctx_turn:] for sp in samples] | |||||
| query_token, src_token, src_pos, src_turn, src_role = [], [], [], [], [] | |||||
| for utts in src: | |||||
| query_token.append(utts[-1]) | |||||
| utt_lens = [len(utt) for utt in utts] | |||||
| # Token ids | |||||
| src_token.append(list(chain(*utts))[-self.max_len:]) | |||||
| # Position ids | |||||
| pos = [list(range(utt_len)) for utt_len in utt_lens] | |||||
| src_pos.append(list(chain(*pos))[-self.max_len:]) | |||||
| # Turn ids | |||||
| turn = [[len(utts) - i] * l for i, l in enumerate(utt_lens)] | |||||
| src_turn.append(list(chain(*turn))[-self.max_len:]) | |||||
| # Role ids | |||||
| role = [ | |||||
| [self.bot_id if (len(utts) - i) % 2 == 0 else self.user_id] * l | |||||
| for i, l in enumerate(utt_lens) | |||||
| ] | |||||
| src_role.append(list(chain(*role))[-self.max_len:]) | |||||
| # src and tgt sequences are padded separately, so that the first token of each stays aligned across the batch | |||||
| src_token = list2np(src_token, padding=self.pad_id) | |||||
| src_pos = list2np(src_pos, padding=self.pad_id) | |||||
| src_turn = list2np(src_turn, padding=self.pad_id) | |||||
| src_role = list2np(src_role, padding=self.pad_id) | |||||
| batch['src_token'] = src_token | |||||
| batch['src_pos'] = src_pos | |||||
| batch['src_type'] = src_role | |||||
| batch['src_turn'] = src_turn | |||||
| batch['src_mask'] = (src_token != self.pad_id).astype('int64') | |||||
| if self.with_query_bow: | |||||
| query_token = list2np(query_token, padding=self.pad_id) | |||||
| batch['query_token'] = query_token | |||||
| batch['query_mask'] = (query_token != self.pad_id).astype('int64') | |||||
| if self.understand_ids and self.understand: | |||||
| understand = [self.understand_ids for _ in samples] | |||||
| understand_token = np.array(understand).astype('int64') | |||||
| batch['understand_token'] = understand_token | |||||
| batch['understand_mask'] = \ | |||||
| (understand_token != self.pad_id).astype('int64') | |||||
| if self.policy_ids and self.policy: | |||||
| policy = [self.policy_ids for _ in samples] | |||||
| policy_token = np.array(policy).astype('int64') | |||||
| batch['policy_token'] = policy_token | |||||
| batch['policy_mask'] = \ | |||||
| (policy_token != self.pad_id).astype('int64') | |||||
| if 'tgt' in samples[0]: | |||||
| tgt = [sp['tgt'] for sp in samples] | |||||
| # Token ids & Label ids | |||||
| tgt_token = list2np(tgt, padding=self.pad_id) | |||||
| # Position ids | |||||
| tgt_pos = np.zeros_like(tgt_token) | |||||
| tgt_pos[:] = np.arange(tgt_token.shape[1], dtype=tgt_token.dtype) | |||||
| # Turn ids | |||||
| tgt_turn = np.zeros_like(tgt_token) | |||||
| # Role ids | |||||
| tgt_role = np.full_like(tgt_token, self.bot_id) | |||||
| batch['tgt_token'] = tgt_token | |||||
| batch['tgt_pos'] = tgt_pos | |||||
| batch['tgt_type'] = tgt_role | |||||
| batch['tgt_turn'] = tgt_turn | |||||
| batch['tgt_mask'] = (tgt_token != self.pad_id).astype('int64') | |||||
| return batch, batch_size | |||||
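| # Shape walkthrough (a sketch, assuming list2np right-pads to the batch max): | |||||
| # for two samples whose 'src' utterances have token lengths [3, 2] and [4], | |||||
| #   src_token -> int64 array of shape (2, 5), right-padded with pad_id | |||||
| #   src_turn  -> [[2, 2, 2, 1, 1], [1, 1, 1, 1, pad]]  (turns counted back from the reply) | |||||
| #   src_type  -> user_id on the final utterance, alternating backwards to bot_id | |||||
| #   src_mask  -> 1 wherever src_token != pad_id, else 0 | |||||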
| def _bucket_by_turn(self, encoded_data): | |||||
| turn_bucket = {} | |||||
| for dial in encoded_data: | |||||
| turn_len = len(dial) | |||||
| if turn_len not in turn_bucket: | |||||
| turn_bucket[turn_len] = [] | |||||
| turn_bucket[turn_len].append(dial) | |||||
| return OrderedDict(sorted(turn_bucket.items(), key=lambda i: i[0])) | |||||
| def _construct_mini_batch(self, data): | |||||
| all_batches = [] | |||||
| batch = [] | |||||
| for dial in data: | |||||
| batch.append(dial) | |||||
| if len(batch) == self.batch_size: | |||||
| # print('batch size: %d, batch num +1'%(len(batch))) | |||||
| all_batches.append(batch) | |||||
| batch = [] | |||||
| # if the remainder is larger than half a batch, keep it as its own batch; otherwise merge it into the previous one | |||||
| # print('last batch size: %d, batch num +1'%(len(batch))) | |||||
| # if (len(batch) % len(cfg.cuda_device)) != 0: | |||||
| # batch = batch[:-(len(batch) % len(cfg.cuda_device))] | |||||
| # TODO deal with deleted data | |||||
| if self.gpu <= 1: | |||||
| if len(batch) > 0.5 * self.batch_size: | |||||
| all_batches.append(batch) | |||||
| elif len(all_batches): | |||||
| all_batches[-1].extend(batch) | |||||
| else: | |||||
| all_batches.append(batch) | |||||
| return all_batches | |||||
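| # Standalone sketch of the single-GPU remainder rule above (hypothetical helper, not part of this class): | |||||
| #     def construct_mini_batches(dials, batch_size): | |||||
| #         batches = [dials[i:i + batch_size] for i in range(0, len(dials), batch_size)] | |||||
| #         # a trailing batch at most half full is merged into the previous one | |||||
| #         if len(batches) > 1 and len(batches[-1]) <= 0.5 * batch_size: | |||||
| #             batches[-2].extend(batches.pop()) | |||||
| #         return batches | |||||
| # e.g. construct_mini_batches(list(range(10)), batch_size=4) -> [[0, 1, 2, 3], [4, 5, 6, 7, 8, 9]] | |||||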
| def transpose_batch(self, batch): | |||||
| dial_batch = [] | |||||
| turn_num = len(batch[0]) | |||||
| for turn in range(turn_num): | |||||
| turn_l = {} | |||||
| for dial in batch: | |||||
| this_turn = dial[turn] | |||||
| for k in this_turn: | |||||
| if k not in turn_l: | |||||
| turn_l[k] = [] | |||||
| turn_l[k].append(this_turn[k]) | |||||
| dial_batch.append(turn_l) | |||||
| return dial_batch | |||||
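| # Example (a sketch): two dialogs of two turns each, | |||||
| #     batch = [[{'user': u00}, {'user': u01}], [{'user': u10}, {'user': u11}]] | |||||
| # transposes into one dict per turn index, batched over dialogs: | |||||
| #     [{'user': [u00, u10]}, {'user': [u01, u11]}] | |||||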
| def get_eval_data(self, set_name='dev'): | |||||
| name_to_set = {'train': self.train, 'test': self.test, 'dev': self.dev} | |||||
| dial = name_to_set[set_name] | |||||
| if set_name not in self.set_stats: | |||||
| self.set_stats[set_name] = {} | |||||
| num_turns = 0 | |||||
| num_dials = len(dial) | |||||
| for d in dial: | |||||
| num_turns += len(d) | |||||
| self.set_stats[set_name]['num_turns'] = num_turns | |||||
| self.set_stats[set_name]['num_dials'] = num_dials | |||||
| return dial | |||||
| def get_nontranspose_data_iterator(self, all_batches): | |||||
| for i, batch in enumerate(all_batches): | |||||
| yield batch | |||||
| def get_data_iterator(self, all_batches): | |||||
| for i, batch in enumerate(all_batches): | |||||
| yield self.transpose_batch(batch) | |||||
| class MultiWOZBPETextField(BPETextField): | |||||
| def __init__(self, model_dir, config): | |||||
| super(MultiWOZBPETextField, self).__init__(config) | |||||
| import spacy | |||||
| self.nlp = spacy.load('en_core_web_sm') | |||||
| self.db = MultiWozDB( | |||||
| model_dir, { | |||||
| 'attraction': 'db/attraction_db_processed.json', | |||||
| 'hospital': 'db/hospital_db_processed.json', | |||||
| 'hotel': 'db/hotel_db_processed.json', | |||||
| 'police': 'db/police_db_processed.json', | |||||
| 'restaurant': 'db/restaurant_db_processed.json', | |||||
| 'taxi': 'db/taxi_db_processed.json', | |||||
| 'train': 'db/train_db_processed.json', | |||||
| }) | |||||
| self._build_vocab(model_dir) | |||||
| special_tokens = [ | |||||
| self.pad_token, self.bos_token, self.eos_token, self.unk_token | |||||
| ] | |||||
| special_tokens.extend(self.add_special_tokens()) | |||||
| self.tokenizer = Tokenizer( | |||||
| vocab_path=os.path.join(model_dir, 'vocab.txt'), | |||||
| special_tokens=special_tokens, | |||||
| tokenizer_type=config.BPETextField.tokenizer_type) | |||||
| self.understand_ids = self.tokenizer.convert_tokens_to_ids( | |||||
| self.understand_tokens) | |||||
| self.policy_ids = self.tokenizer.convert_tokens_to_ids( | |||||
| self.policy_tokens) | |||||
| return | |||||
| def get_ids(self, data: str): | |||||
| result = [self.sos_u_id] + self.tokenizer.convert_tokens_to_ids( | |||||
| self.tokenizer.tokenize( | |||||
| self._get_convert_str(data))) + [self.eos_u_id] | |||||
| return result | |||||
| def inverse_transpose_turn(self, turn_list): | |||||
| """ | |||||
| eval, one dialog at a time | |||||
| """ | |||||
| dialogs = {} | |||||
| turn_num = len(turn_list) | |||||
| dial_id = turn_list[0]['dial_id'] | |||||
| dialogs[dial_id] = [] | |||||
| for turn_idx in range(turn_num): | |||||
| dial_turn = {} | |||||
| turn = turn_list[turn_idx] | |||||
| for key, value in turn.items(): | |||||
| if key == 'dial_id': | |||||
| continue | |||||
| if key == 'pointer' and self.db is not None: | |||||
| turn_domain = turn['turn_domain'][-1] | |||||
| value = self.db.pointerBack(value, turn_domain) | |||||
| dial_turn[key] = value | |||||
| dialogs[dial_id].append(dial_turn) | |||||
| return dialogs | |||||
| def inverse_transpose_batch(self, turn_batch_list): | |||||
| """ | |||||
| :param turn_batch_list: list of transposed dialog batches | |||||
| """ | |||||
| dialogs = {} | |||||
| total_turn_num = len(turn_batch_list) | |||||
| # initialize | |||||
| for idx_in_batch, dial_id in enumerate(turn_batch_list[0]['dial_id']): | |||||
| dialogs[dial_id] = [] | |||||
| for turn_n in range(total_turn_num): | |||||
| dial_turn = {} | |||||
| turn_batch = turn_batch_list[turn_n] | |||||
| for key, v_list in turn_batch.items(): | |||||
| if key == 'dial_id': | |||||
| continue | |||||
| value = v_list[idx_in_batch] | |||||
| if key == 'pointer' and self.db is not None: | |||||
| turn_domain = turn_batch['turn_domain'][idx_in_batch][ | |||||
| -1] | |||||
| value = self.db.pointerBack(value, turn_domain) | |||||
| dial_turn[key] = value | |||||
| dialogs[dial_id].append(dial_turn) | |||||
| return dialogs | |||||
| def get_batches(self, set_name): | |||||
| """ | |||||
| compute dataset stats. | |||||
| """ | |||||
| log_str = '' | |||||
| name_to_set = {'train': self.train, 'test': self.test, 'dev': self.dev} | |||||
| dial = name_to_set[set_name] | |||||
| turn_bucket = self._bucket_by_turn(dial) | |||||
| # self._shuffle_turn_bucket(turn_bucket) | |||||
| all_batches = [] | |||||
| if set_name not in self.set_stats: | |||||
| self.set_stats[set_name] = {} | |||||
| num_training_steps = 0 | |||||
| num_turns = 0 | |||||
| num_dials = 0 | |||||
| for k in turn_bucket: | |||||
| if (set_name != 'test' and k == 1) or k >= 17: | |||||
| continue | |||||
| batches = self._construct_mini_batch(turn_bucket[k]) | |||||
| try: | |||||
| log_str += 'turn num:%d, dial num: %d, batch num: %d last batch len: %d\n' % ( | |||||
| k, len(turn_bucket[k]), len(batches), len(batches[-1])) | |||||
| except Exception: | |||||
| log_str += 'turn num:%d, dial num: %d, batch num: %d last batch len: %d\n' % ( | |||||
| k, len(turn_bucket[k]), len(batches), 0.0) | |||||
| # print("turn num:%d, dial num:v%d, batch num: %d, "%(k, len(turn_bucket[k]), len(batches))) | |||||
| num_training_steps += k * len(batches) | |||||
| num_turns += k * len(turn_bucket[k]) | |||||
| num_dials += len(turn_bucket[k]) | |||||
| all_batches += batches | |||||
| log_str += 'total batch num: %d\n' % len(all_batches) | |||||
| # print('total batch num: %d'%len(all_batches)) | |||||
| # print('dialog count: %d'%dia_count) | |||||
| # return all_batches | |||||
| # log stats | |||||
| # logging.info(log_str) | |||||
| # cfg.num_training_steps = num_training_steps * cfg.epoch_num | |||||
| self.set_stats[set_name][ | |||||
| 'num_training_steps_per_epoch'] = num_training_steps # turn-level steps | |||||
| self.set_stats[set_name]['num_turns'] = num_turns | |||||
| self.set_stats[set_name]['num_dials'] = num_dials | |||||
| if set_name == 'train': | |||||
| random.shuffle(all_batches) | |||||
| return all_batches | |||||
| def add_special_tokens(self): | |||||
| """ | |||||
| add special tokens to the gpt tokenizer; | |||||
| serves a similar role to Vocab.construct() by | |||||
| building the list of special tokens | |||||
| """ | |||||
| special_tokens = [] | |||||
| prompt_tokens = self.understand_tokens + self.policy_tokens | |||||
| special_tokens.extend( | |||||
| ontology.get_special_tokens(other_tokens=prompt_tokens)) | |||||
| for word in ontology.all_domains + ['general']: | |||||
| word = '[' + word + ']' | |||||
| special_tokens.append(word) | |||||
| for word in ontology.all_acts: | |||||
| word = '[' + word + ']' | |||||
| special_tokens.append(word) | |||||
| for word in self.vocab._word2idx.keys(): | |||||
| if word.startswith('[value_') and word.endswith(']'): | |||||
| special_tokens.append(word) | |||||
| return special_tokens | |||||
| def _build_vocab(self, model_dir: str): | |||||
| self.vocab = utils.MultiWOZVocab(3000) | |||||
| vp = os.path.join(model_dir, 'vocab') | |||||
| self.vocab.load_vocab(vp) | |||||
| return self.vocab.vocab_size | |||||
| def _get_convert_str(self, sent): | |||||
| assert isinstance(sent, str) | |||||
| return ' '.join([ | |||||
| self.tokenizer.spec_convert_dict.get(tok, tok) | |||||
| for tok in sent.split() | |||||
| ]) | |||||
| def bspan_to_DBpointer(self, bspan, turn_domain): | |||||
| constraint_dict = self.bspan_to_constraint_dict(bspan) | |||||
| # print(constraint_dict) | |||||
| matnums = self.db.get_match_num(constraint_dict) | |||||
| match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1] | |||||
| match_dom = match_dom[1:-1] if match_dom.startswith('[') else match_dom | |||||
| match = matnums[match_dom] | |||||
| # vector = self.db.addDBPointer(match_dom, match) | |||||
| vector = self.db.addDBIndicator(match_dom, match) | |||||
| return vector | |||||
| def bspan_to_constraint_dict(self, bspan, bspn_mode='bspn'): | |||||
| """ | |||||
| ['[hotel]', 'pricerange', 'cheap', 'type', 'hotel'] -> {'hotel': {'pricerange': 'cheap', 'type': 'hotel'}} | |||||
| """ | |||||
| bspan = bspan.split() if isinstance(bspan, str) else bspan | |||||
| constraint_dict = {} | |||||
| domain = None | |||||
| conslen = len(bspan) | |||||
| for idx, cons in enumerate(bspan): | |||||
| cons = self.vocab.decode(cons) if type(cons) is not str else cons | |||||
| if cons == '<eos_b>': | |||||
| break | |||||
| if '[' in cons: | |||||
| if cons[1:-1] not in ontology.all_domains: | |||||
| continue | |||||
| domain = cons[1:-1] | |||||
| elif cons in ontology.get_slot: | |||||
| if domain is None: | |||||
| continue | |||||
| if cons == 'people': | |||||
| # disambiguate the slot 'people' from values like "people's portraits..." | |||||
| try: | |||||
| ns = bspan[idx + 1] | |||||
| ns = self.vocab.decode(ns) if type( | |||||
| ns) is not str else ns | |||||
| if ns == "'s": | |||||
| continue | |||||
| except Exception: | |||||
| continue | |||||
| if not constraint_dict.get(domain): | |||||
| constraint_dict[domain] = {} | |||||
| if bspn_mode == 'bsdx': | |||||
| constraint_dict[domain][cons] = 1 | |||||
| continue | |||||
| vidx = idx + 1 | |||||
| if vidx == conslen: | |||||
| break | |||||
| vt_collect = [] | |||||
| vt = bspan[vidx] | |||||
| vt = self.vocab.decode(vt) if type(vt) is not str else vt | |||||
| while vidx < conslen and vt != '<eos_b>' and '[' not in vt and vt not in ontology.get_slot: | |||||
| vt_collect.append(vt) | |||||
| vidx += 1 | |||||
| if vidx == conslen: | |||||
| break | |||||
| vt = bspan[vidx] | |||||
| vt = self.vocab.decode(vt) if type(vt) is not str else vt | |||||
| if vt_collect: | |||||
| constraint_dict[domain][cons] = ' '.join(vt_collect) | |||||
| return constraint_dict | |||||
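| # Multi-token values are collected until the next slot/domain/<eos_b> token, | |||||
| # so (assuming 'name' is a slot in ontology.get_slot): | |||||
| #     '[hotel] name gonville hotel <eos_b>' -> {'hotel': {'name': 'gonville hotel'}} | |||||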
| def convert_batch_turn(self, turn_batch, pv_batch, first_turn=False): | |||||
| """ | |||||
| convert the current and the last turn | |||||
| concat [U_0,R_0,...,U_{t-1}, R_{t-1}, U_t, B_t, A_t, R_t] | |||||
| first turn: [U_t, B_t, A_t, R_t] | |||||
| try: [user, bspn, db, aspn, resp] | |||||
| """ | |||||
| inputs = [] | |||||
| if first_turn: | |||||
| batch_zipped = zip(turn_batch['user'], turn_batch['bspn'], | |||||
| turn_batch['db'], turn_batch['aspn'], | |||||
| turn_batch['resp']) | |||||
| for u, b, db, a, r in batch_zipped: | |||||
| if self.use_true_curr_bspn: | |||||
| src = [u + b + db] | |||||
| tgt = a + r | |||||
| else: | |||||
| src = [u] | |||||
| tgt = b + db + a + r | |||||
| inputs.append({'src': src, 'tgt': tgt}) | |||||
| pv = [src[-1], tgt] | |||||
| pv_batch.append(pv) | |||||
| else: | |||||
| batch_zipped = zip(pv_batch, turn_batch['user'], | |||||
| turn_batch['bspn'], turn_batch['db'], | |||||
| turn_batch['aspn'], turn_batch['resp']) | |||||
| for i, (pv, u, b, db, a, r) in enumerate(batch_zipped): | |||||
| if self.use_true_curr_bspn: | |||||
| src = pv + [u + b + db] | |||||
| tgt = a + r | |||||
| else: | |||||
| src = pv + [u] | |||||
| tgt = b + db + a + r | |||||
| inputs.append({'src': src, 'tgt': tgt}) | |||||
| pv = [src[-1], tgt] | |||||
| pv_batch[i].extend(pv) | |||||
| return inputs, pv_batch | |||||
| def wrap_result_lm(self, result_dict, eos_syntax=None): | |||||
| results = [] | |||||
| eos_syntax = ontology.eos_tokens if not eos_syntax else eos_syntax | |||||
| sos_syntax = ontology.sos_tokens | |||||
| # ground truth bs, as, ds.. generate response | |||||
| field = [ | |||||
| 'dial_id', 'turn_num', 'user', 'bspn_gen', 'bsdx', 'resp_gen', | |||||
| 'resp', 'aspn_gen', 'aspn', 'dspn_gen', 'dspn', 'bspn', 'pointer', | |||||
| 'qspn_gen', 'qspn' | |||||
| ] | |||||
| for dial_id, turns in result_dict.items(): | |||||
| entry = {'dial_id': dial_id, 'turn_num': len(turns)} | |||||
| for f in field[2:]: | |||||
| entry[f] = '' # TODO ??? | |||||
| results.append(entry) | |||||
| for turn_idx, turn in enumerate(turns): | |||||
| entry = {'dial_id': dial_id} | |||||
| for key in field: | |||||
| if key in ['dial_id']: | |||||
| continue | |||||
| v = turn.get(key, '') | |||||
| if key == 'turn_domain': | |||||
| v = ' '.join(v) | |||||
| if key in eos_syntax and v != '': | |||||
| # remove eos tokens | |||||
| v = self.tokenizer.decode(v) | |||||
| v = v.split() | |||||
| # remove eos/sos in span | |||||
| if eos_syntax[key] in v: | |||||
| v.remove(eos_syntax[key]) | |||||
| if sos_syntax[key] in v: | |||||
| v.remove(sos_syntax[key]) | |||||
| v = ' '.join(v) | |||||
| else: | |||||
| pass # v = v | |||||
| entry[key] = v | |||||
| results.append(entry) | |||||
| return results, field | |||||
| def convert_turn_eval(self, turn, pv_turn, first_turn=False): | |||||
| """ | |||||
| input: [all previous ubar, U_t, B_t, A_t] predict R_t | |||||
| first turn: [U_t, B_t, A_t] predict R_t | |||||
| regarding the context, using all previous ubar history is too slow, so try only the previous turn | |||||
| """ | |||||
| inputs = {} | |||||
| context_list = [] | |||||
| prompt_id = None | |||||
| if self.use_true_curr_bspn: | |||||
| if self.use_true_curr_aspn: # only predict resp | |||||
| context_list = ['user', 'bspn', 'db', 'aspn'] | |||||
| prompt_id = self.sos_r_id | |||||
| else: # predicted aspn | |||||
| context_list = ['user', 'bspn', 'db'] | |||||
| prompt_id = self.sos_a_id | |||||
| else: # predict bspn aspn resp. db are not predicted. this part tbd. | |||||
| context_list = ['user'] | |||||
| prompt_id = self.sos_b_id | |||||
| if first_turn: | |||||
| context = [] | |||||
| for c in context_list: | |||||
| context += turn[c] | |||||
| inputs['src'] = [context] | |||||
| inputs['labels'] = [context] | |||||
| else: | |||||
| context = [] | |||||
| for c in context_list: | |||||
| context += turn[c] | |||||
| if self.use_true_curr_bspn: | |||||
| pv_context = pv_turn['labels'] + [ | |||||
| pv_turn['aspn'] + pv_turn['resp'] | |||||
| ] | |||||
| else: | |||||
| pv_info = pv_turn['bspn'] + pv_turn['db'] + pv_turn[ | |||||
| 'aspn'] + pv_turn['resp'] | |||||
| pv_context = pv_turn['labels'] + [pv_info] | |||||
| # prompt response, add sos_r | |||||
| inputs['src'] = pv_context + [context] | |||||
| if self.use_all_previous_context: | |||||
| inputs['labels'] = pv_context + [ | |||||
| context | |||||
| ] # use all previous ubar history | |||||
| else: | |||||
| inputs['labels'] = [context] # use previous turn | |||||
| return inputs, prompt_id | |||||
| @@ -0,0 +1,668 @@ | |||||
| from __future__ import (absolute_import, division, print_function, | |||||
| unicode_literals) | |||||
| import collections | |||||
| import logging | |||||
| import os | |||||
| import sys | |||||
| import unicodedata | |||||
| import json | |||||
| import regex as re | |||||
| def clean_string(string): | |||||
| replace_mp = { | |||||
| ' - ': '-', | |||||
| " ' ": "'", | |||||
| " n't": "n't", | |||||
| " 'm": "'m", | |||||
| ' do not': " don't", | |||||
| " 's": "'s", | |||||
| " 've": "'ve", | |||||
| " 're": "'re" | |||||
| } | |||||
| for k, v in replace_mp.items(): | |||||
| string = string.replace(k, v) | |||||
| return string | |||||
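| # Example: clean_string("i do not know , it 's fine") == "i don't know , it's fine" | |||||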
| class Tokenizer(object): | |||||
| def __init__(self, vocab_path, special_tokens=[], tokenizer_type='Bert'): | |||||
| self.tokenizer_type = tokenizer_type | |||||
| if tokenizer_type == 'Bert': | |||||
| self.spec_convert_dict = { | |||||
| '[BOS]': '[unused0]', | |||||
| '[EOS]': '[unused1]' | |||||
| } | |||||
| for token in special_tokens: | |||||
| if token not in self.spec_convert_dict and token not in [ | |||||
| '[PAD]', '[UNK]' | |||||
| ]: | |||||
| self.spec_convert_dict[ | |||||
| token] = f'[unused{len(self.spec_convert_dict)}]' | |||||
| self.spec_revert_dict = { | |||||
| v: k | |||||
| for k, v in self.spec_convert_dict.items() | |||||
| } | |||||
| special_tokens = [ | |||||
| self.spec_convert_dict.get(tok, tok) for tok in special_tokens | |||||
| ] | |||||
| self.special_tokens = ('[UNK]', '[SEP]', '[PAD]', '[CLS]', | |||||
| '[MASK]') | |||||
| self.special_tokens += tuple(x for x in special_tokens | |||||
| if x not in self.special_tokens) | |||||
| self._tokenizer = BertTokenizer( | |||||
| vocab_path, never_split=self.special_tokens) | |||||
| for tok in self.special_tokens: | |||||
| assert tok in self._tokenizer.vocab, f"special token '{tok}' is not in the vocabulary" | |||||
| self.vocab_size = len(self._tokenizer.vocab) | |||||
| elif tokenizer_type == 'GPT2': | |||||
| self.spec_convert_dict = {'[UNK]': '<unk>'} | |||||
| self.spec_revert_dict = { | |||||
| v: k | |||||
| for k, v in self.spec_convert_dict.items() | |||||
| } | |||||
| special_tokens = [ | |||||
| tok for tok in special_tokens | |||||
| if tok not in self.spec_convert_dict | |||||
| ] | |||||
| vocab_file = os.path.join(vocab_path, 'vocab.json') | |||||
| merges_file = os.path.join(vocab_path, 'merges.txt') | |||||
| self._tokenizer = GPT2Tokenizer( | |||||
| vocab_file, merges_file, special_tokens=special_tokens) | |||||
| self.num_specials = len(special_tokens) | |||||
| self.vocab_size = len(self._tokenizer) | |||||
| else: | |||||
| raise ValueError | |||||
| def tokenize(self, text): | |||||
| return self._tokenizer.tokenize(text) | |||||
| def convert_tokens_to_ids(self, tokens): | |||||
| if self.tokenizer_type == 'Bert': | |||||
| tokens = [self.spec_convert_dict.get(tok, tok) for tok in tokens] | |||||
| ids = self._tokenizer.convert_tokens_to_ids(tokens) | |||||
| return ids | |||||
| else: | |||||
| tokens = [self.spec_convert_dict.get(tok, tok) for tok in tokens] | |||||
| ids = self._tokenizer.convert_tokens_to_ids(tokens) | |||||
| ids = [(i + self.num_specials) % self.vocab_size for i in ids] | |||||
| return ids | |||||
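| # Why the modular shift in the GPT2 branch: GPT2Tokenizer appends the special | |||||
| # tokens after the base vocab (e.g. base 50257 + 3 specials -> ids 50257..50259). | |||||
| # Adding num_specials modulo vocab_size relocates the specials to 0..2 and moves | |||||
| # every base id up by 3, keeping special ids small and stable. | |||||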
| def convert_ids_to_tokens(self, ids): | |||||
| if self.tokenizer_type == 'Bert': | |||||
| tokens = self._tokenizer.convert_ids_to_tokens(ids) | |||||
| tokens = [self.spec_revert_dict.get(tok, tok) for tok in tokens] | |||||
| return tokens | |||||
| else: | |||||
| ids = [(i - self.num_specials) % self.vocab_size for i in ids] | |||||
| tokens = self._tokenizer.convert_ids_to_tokens(ids) | |||||
| tokens = [self.spec_revert_dict.get(tok, tok) for tok in tokens] | |||||
| return tokens | |||||
| def decode(self, ids, ignore_tokens=[]): | |||||
| tokens = self.convert_ids_to_tokens(ids) | |||||
| if len(ignore_tokens) > 0: | |||||
| ignore_tokens = set(ignore_tokens) | |||||
| tokens = [tok for tok in tokens if tok not in ignore_tokens] | |||||
| if self.tokenizer_type == 'Bert': | |||||
| string = ' '.join(tokens).replace(' ##', '') | |||||
| else: | |||||
| string = ''.join(tokens) | |||||
| string = bytearray([ | |||||
| self._tokenizer.byte_decoder[c] for c in string | |||||
| ]).decode('utf-8') | |||||
| string = clean_string(string) | |||||
| return string | |||||
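| # Example (Bert branch): ids for ['un', '##aff', '##able'] decode to 'unaffable', | |||||
| # since ' '.join(...) yields 'un ##aff ##able' and the ' ##' joints are stripped. | |||||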
| # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Tokenization classes.""" | |||||
| logger = logging.getLogger(__name__) | |||||
| def load_vocab(vocab_file): | |||||
| """Loads a vocabulary file into a dictionary.""" | |||||
| vocab = collections.OrderedDict() | |||||
| index = 0 | |||||
| with open(vocab_file, 'r', encoding='utf-8') as reader: | |||||
| while True: | |||||
| token = reader.readline() | |||||
| if not token: | |||||
| break | |||||
| token = token.strip() | |||||
| vocab[token] = index | |||||
| index += 1 | |||||
| return vocab | |||||
| def whitespace_tokenize(text): | |||||
| """Runs basic whitespace cleaning and splitting on a piece of text.""" | |||||
| text = text.strip() | |||||
| if not text: | |||||
| return [] | |||||
| tokens = text.split() | |||||
| return tokens | |||||
| class BertTokenizer(object): | |||||
| """Runs end-to-end tokenization: punctuation splitting + wordpiece""" | |||||
| def __init__(self, | |||||
| vocab_file, | |||||
| do_lower_case=True, | |||||
| max_len=None, | |||||
| do_basic_tokenize=True, | |||||
| never_split=('[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]')): | |||||
| """Constructs a BertTokenizer. | |||||
| Args: | |||||
| vocab_file: Path to a one-wordpiece-per-line vocabulary file | |||||
| do_lower_case: Whether to lower case the input | |||||
| Only has an effect when do_basic_tokenize=True | |||||
| do_basic_tokenize: Whether to do basic tokenization before wordpiece. | |||||
| max_len: An artificial maximum length to truncate tokenized sequences to; | |||||
| Effective maximum length is always the minimum of this | |||||
| value (if specified) and the underlying BERT model's | |||||
| sequence length. | |||||
| never_split: List of tokens which will never be split during tokenization. | |||||
| Only has an effect when do_basic_tokenize=True | |||||
| """ | |||||
| if not os.path.isfile(vocab_file): | |||||
| raise ValueError( | |||||
| "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " | |||||
| 'model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`' | |||||
| .format(vocab_file)) | |||||
| self.vocab = load_vocab(vocab_file) | |||||
| self.ids_to_tokens = collections.OrderedDict([ | |||||
| (ids, tok) for tok, ids in self.vocab.items() | |||||
| ]) | |||||
| self.do_basic_tokenize = do_basic_tokenize | |||||
| if do_basic_tokenize: | |||||
| self.basic_tokenizer = BasicTokenizer( | |||||
| do_lower_case=do_lower_case, never_split=never_split) | |||||
| self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) | |||||
| self.max_len = max_len if max_len is not None else int(1e12) | |||||
| def tokenize(self, text): | |||||
| split_tokens = [] | |||||
| if self.do_basic_tokenize: | |||||
| for token in self.basic_tokenizer.tokenize(text): | |||||
| for sub_token in self.wordpiece_tokenizer.tokenize(token): | |||||
| split_tokens.append(sub_token) | |||||
| else: | |||||
| split_tokens = self.wordpiece_tokenizer.tokenize(text) | |||||
| return split_tokens | |||||
| def convert_tokens_to_ids(self, tokens): | |||||
| """Converts a sequence of tokens into ids using the vocab.""" | |||||
| ids = [] | |||||
| for token in tokens: | |||||
| ids.append(self.vocab[token]) | |||||
| if len(ids) > self.max_len: | |||||
| logger.warning( | |||||
| 'Token indices sequence length is longer than the specified maximum ' | |||||
| ' sequence length for this BERT model ({} > {}). Running this' | |||||
| ' sequence through BERT will result in indexing errors'.format( | |||||
| len(ids), self.max_len)) | |||||
| return ids | |||||
| def convert_ids_to_tokens(self, ids): | |||||
| """Converts a sequence of ids in wordpiece tokens using the vocab.""" | |||||
| tokens = [] | |||||
| for i in ids: | |||||
| tokens.append(self.ids_to_tokens[i]) | |||||
| return tokens | |||||
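| # Quick self-contained check of the classes above (a sketch with a toy vocab): | |||||
| #     import tempfile | |||||
| #     vocab = ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', 'hello', 'world', '##s'] | |||||
| #     with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f: | |||||
| #         f.write('\n'.join(vocab)) | |||||
| #     tok = BertTokenizer(f.name) | |||||
| #     assert tok.tokenize('Hello worlds') == ['hello', 'world', '##s'] | |||||
| #     assert tok.convert_tokens_to_ids(['hello', 'world']) == [5, 6] | |||||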
| class BasicTokenizer(object): | |||||
| """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" | |||||
| def __init__(self, | |||||
| do_lower_case=True, | |||||
| never_split=('[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]')): | |||||
| """Constructs a BasicTokenizer. | |||||
| Args: | |||||
| do_lower_case: Whether to lower case the input. | |||||
| """ | |||||
| self.do_lower_case = do_lower_case | |||||
| self.never_split = never_split | |||||
| def tokenize(self, text): | |||||
| """Tokenizes a piece of text.""" | |||||
| text = self._clean_text(text) | |||||
| # This was added on November 1st, 2018 for the multilingual and Chinese | |||||
| # models. This is also applied to the English models now, but it doesn't | |||||
| # matter since the English models were not trained on any Chinese data | |||||
| # and generally don't have any Chinese data in them (there are Chinese | |||||
| # characters in the vocabulary because Wikipedia does have some Chinese | |||||
| # words in the English Wikipedia.). | |||||
| text = self._tokenize_chinese_chars(text) | |||||
| orig_tokens = whitespace_tokenize(text) | |||||
| split_tokens = [] | |||||
| for token in orig_tokens: | |||||
| if self.do_lower_case and token not in self.never_split: | |||||
| token = token.lower() | |||||
| token = self._run_strip_accents(token) | |||||
| split_tokens.extend(self._run_split_on_punc(token)) | |||||
| output_tokens = whitespace_tokenize(' '.join(split_tokens)) | |||||
| return output_tokens | |||||
| def _run_strip_accents(self, text): | |||||
| """Strips accents from a piece of text.""" | |||||
| text = unicodedata.normalize('NFD', text) | |||||
| output = [] | |||||
| for char in text: | |||||
| cat = unicodedata.category(char) | |||||
| if cat == 'Mn': | |||||
| continue | |||||
| output.append(char) | |||||
| return ''.join(output) | |||||
| def _run_split_on_punc(self, text): | |||||
| """Splits punctuation on a piece of text.""" | |||||
| if text in self.never_split: | |||||
| return [text] | |||||
| chars = list(text) | |||||
| i = 0 | |||||
| start_new_word = True | |||||
| output = [] | |||||
| while i < len(chars): | |||||
| char = chars[i] | |||||
| if _is_punctuation(char): | |||||
| output.append([char]) | |||||
| start_new_word = True | |||||
| else: | |||||
| if start_new_word: | |||||
| output.append([]) | |||||
| start_new_word = False | |||||
| output[-1].append(char) | |||||
| i += 1 | |||||
| return [''.join(x) for x in output] | |||||
| def _tokenize_chinese_chars(self, text): | |||||
| """Adds whitespace around any CJK character.""" | |||||
| output = [] | |||||
| for char in text: | |||||
| cp = ord(char) | |||||
| if self._is_chinese_char(cp): | |||||
| output.append(' ') | |||||
| output.append(char) | |||||
| output.append(' ') | |||||
| else: | |||||
| output.append(char) | |||||
| return ''.join(output) | |||||
| def _is_chinese_char(self, cp): | |||||
| """Checks whether CP is the codepoint of a CJK character.""" | |||||
| # This defines a "chinese character" as anything in the CJK Unicode block: | |||||
| # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) | |||||
| # | |||||
| # Note that the CJK Unicode block is NOT all Japanese and Korean characters, | |||||
| # despite its name. The modern Korean Hangul alphabet is a different block, | |||||
| # as is Japanese Hiragana and Katakana. Those alphabets are used to write | |||||
| # space-separated words, so they are not treated specially and handled | |||||
| like all of the other languages. | |||||
| tmp = (cp >= 0x4E00 and cp <= 0x9FFF) | |||||
| tmp = tmp or (cp >= 0x3400 and cp <= 0x4DBF) | |||||
| tmp = tmp or (cp >= 0x20000 and cp <= 0x2A6DF) | |||||
| tmp = tmp or (cp >= 0x2A700 and cp <= 0x2B73F) | |||||
| tmp = tmp or (cp >= 0x2B740 and cp <= 0x2B81F) | |||||
| tmp = tmp or (cp >= 0x2B820 and cp <= 0x2CEAF) | |||||
| tmp = tmp or (cp >= 0xF900 and cp <= 0xFAFF) | |||||
| tmp = tmp or (cp >= 0x2F800 and cp <= 0x2FA1F) | |||||
| if tmp: | |||||
| return True | |||||
| return False | |||||
| def _clean_text(self, text): | |||||
| """Performs invalid character removal and whitespace cleanup on text.""" | |||||
| output = [] | |||||
| for char in text: | |||||
| cp = ord(char) | |||||
| if cp == 0 or cp == 0xfffd or _is_control(char): | |||||
| continue | |||||
| if _is_whitespace(char): | |||||
| output.append(' ') | |||||
| else: | |||||
| output.append(char) | |||||
| return ''.join(output) | |||||
| class WordpieceTokenizer(object): | |||||
| """Runs WordPiece tokenization.""" | |||||
| def __init__(self, vocab, unk_token='[UNK]', max_input_chars_per_word=100): | |||||
| self.vocab = vocab | |||||
| self.unk_token = unk_token | |||||
| self.max_input_chars_per_word = max_input_chars_per_word | |||||
| def tokenize(self, text): | |||||
| """Tokenizes a piece of text into its word pieces. | |||||
| This uses a greedy longest-match-first algorithm to perform tokenization | |||||
| using the given vocabulary. | |||||
| For example: | |||||
| input = "unaffable" | |||||
| output = ["un", "##aff", "##able"] | |||||
| Args: | |||||
| text: A single token or whitespace separated tokens. This should have | |||||
| already been passed through `BasicTokenizer`. | |||||
| Returns: | |||||
| A list of wordpiece tokens. | |||||
| """ | |||||
| output_tokens = [] | |||||
| for token in whitespace_tokenize(text): | |||||
| chars = list(token) | |||||
| if len(chars) > self.max_input_chars_per_word: | |||||
| output_tokens.append(self.unk_token) | |||||
| continue | |||||
| is_bad = False | |||||
| start = 0 | |||||
| sub_tokens = [] | |||||
| while start < len(chars): | |||||
| end = len(chars) | |||||
| cur_substr = None | |||||
| while start < end: | |||||
| substr = ''.join(chars[start:end]) | |||||
| if start > 0: | |||||
| substr = '##' + substr | |||||
| if substr in self.vocab: | |||||
| cur_substr = substr | |||||
| break | |||||
| end -= 1 | |||||
| if cur_substr is None: | |||||
| is_bad = True | |||||
| break | |||||
| sub_tokens.append(cur_substr) | |||||
| start = end | |||||
| if is_bad: | |||||
| output_tokens.append(self.unk_token) | |||||
| else: | |||||
| output_tokens.extend(sub_tokens) | |||||
| return output_tokens | |||||
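| # Example of the greedy longest-match-first behaviour: | |||||
| #     wp = WordpieceTokenizer(vocab={'un': 0, '##aff': 1, '##able': 2}) | |||||
| #     wp.tokenize('unaffable')  # -> ['un', '##aff', '##able'] | |||||
| #     wp.tokenize('xyz')        # -> ['[UNK]'] (no vocab match from position 0) | |||||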
| def _is_whitespace(char): | |||||
| """Checks whether `chars` is a whitespace character.""" | |||||
| # \t, \n, and \r are technically control characters but we treat them | |||||
| # as whitespace since they are generally considered as such. | |||||
| if char == ' ' or char == '\t' or char == '\n' or char == '\r': | |||||
| return True | |||||
| cat = unicodedata.category(char) | |||||
| if cat == 'Zs': | |||||
| return True | |||||
| return False | |||||
| def _is_control(char): | |||||
| """Checks whether `chars` is a control character.""" | |||||
| # These are technically control characters but we count them as whitespace | |||||
| # characters. | |||||
| if char == '\t' or char == '\n' or char == '\r': | |||||
| return False | |||||
| cat = unicodedata.category(char) | |||||
| if cat.startswith('C'): | |||||
| return True | |||||
| return False | |||||
| def _is_punctuation(char): | |||||
| """Checks whether `chars` is a punctuation character.""" | |||||
| cp = ord(char) | |||||
| # We treat all non-letter/number ASCII as punctuation. | |||||
| # Characters such as "^", "$", and "`" are not in the Unicode | |||||
| # Punctuation class but we treat them as punctuation anyways, for | |||||
| # consistency. | |||||
| tmp = (cp >= 33 and cp <= 47) | |||||
| tmp = tmp or (cp >= 58 and cp <= 64) | |||||
| tmp = tmp or (cp >= 91 and cp <= 96) | |||||
| tmp = tmp or (cp >= 123 and cp <= 126) | |||||
| if tmp: | |||||
| return True | |||||
| cat = unicodedata.category(char) | |||||
| if cat.startswith('P'): | |||||
| return True | |||||
| return False | |||||
| # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Tokenization classes for OpenAI GPT.""" | |||||
| try: | |||||
| from functools import lru_cache | |||||
| except ImportError: | |||||
| # Just a dummy decorator to get the checks to run on python2 | |||||
| # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. | |||||
| def lru_cache(): | |||||
| return lambda func: func | |||||
| @lru_cache() | |||||
| def bytes_to_unicode(): | |||||
| """ | |||||
| Returns a dict mapping utf-8 bytes to unicode strings. | |||||
| The reversible bpe codes work on unicode strings. | |||||
| This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. | |||||
| When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. | |||||
| This is a significant percentage of your normal, say, 32K bpe vocab. | |||||
| To avoid that, we want lookup tables between utf-8 bytes and unicode strings. | |||||
| And avoids mapping to whitespace/control characters the bpe code barfs on. | |||||
| """ | |||||
| _chr = unichr if sys.version_info[0] == 2 else chr | |||||
| bs = list(range(ord('!'), | |||||
| ord('~') + 1)) + list(range( | |||||
| ord('¡'), | |||||
| ord('¬') + 1)) + list(range(ord('®'), | |||||
| ord('ÿ') + 1)) | |||||
| cs = bs[:] | |||||
| n = 0 | |||||
| for b in range(2**8): | |||||
| if b not in bs: | |||||
| bs.append(b) | |||||
| cs.append(2**8 + n) | |||||
| n += 1 | |||||
| cs = [_chr(n) for n in cs] | |||||
| return dict(zip(bs, cs)) | |||||
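| # The mapping is a bijection over all 256 byte values: printable latin-1 bytes | |||||
| # map to themselves, the rest are remapped above U+00FF in byte order, e.g. | |||||
| #     b2u = bytes_to_unicode() | |||||
| #     len(b2u)        # 256 | |||||
| #     b2u[ord('A')]   # 'A' | |||||
| #     b2u[ord(' ')]   # 'Ġ', i.e. chr(256 + 32): space is remapped | |||||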
| def get_pairs(word): | |||||
| """Return set of symbol pairs in a word. | |||||
| Word is represented as tuple of symbols (symbols being variable-length strings). | |||||
| """ | |||||
| pairs = set() | |||||
| prev_char = word[0] | |||||
| for char in word[1:]: | |||||
| pairs.add((prev_char, char)) | |||||
| prev_char = char | |||||
| return pairs | |||||
| class GPT2Tokenizer(object): | |||||
| """ | |||||
| GPT-2 BPE tokenizer. Peculiarities: | |||||
| - Byte-level BPE | |||||
| """ | |||||
| def __init__(self, | |||||
| vocab_file, | |||||
| merges_file, | |||||
| errors='replace', | |||||
| special_tokens=None, | |||||
| max_len=None): | |||||
| self.max_len = max_len if max_len is not None else int(1e12) | |||||
| with open(vocab_file, encoding='utf-8') as f: | |||||
| self.encoder = json.load(f) | |||||
| self.decoder = {v: k for k, v in self.encoder.items()} | |||||
| self.errors = errors # how to handle errors in decoding | |||||
| self.byte_encoder = bytes_to_unicode() | |||||
| self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} | |||||
| with open(merges_file, encoding='utf-8') as f: | |||||
| bpe_data = f.read().split('\n')[1:-1] | |||||
| bpe_merges = [tuple(merge.split()) for merge in bpe_data] | |||||
| self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) | |||||
| self.cache = {} | |||||
| # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions | |||||
| self.pat = re.compile( | |||||
| r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" | |||||
| ) | |||||
| self.special_tokens = {} | |||||
| self.special_tokens_decoder = {} | |||||
| self.set_special_tokens(special_tokens) | |||||
| def __len__(self): | |||||
| return len(self.encoder) + len(self.special_tokens) | |||||
| def set_special_tokens(self, special_tokens): | |||||
| """ Add a list of additional tokens to the encoder. | |||||
| The additional tokens are indexed starting from the last index of the | |||||
| current vocabulary in the order of the `special_tokens` list. | |||||
| """ | |||||
| if not special_tokens: | |||||
| self.special_tokens = {} | |||||
| self.special_tokens_decoder = {} | |||||
| return | |||||
| self.special_tokens = dict((tok, len(self.encoder) + i) | |||||
| for i, tok in enumerate(special_tokens)) | |||||
| self.special_tokens_decoder = { | |||||
| v: k | |||||
| for k, v in self.special_tokens.items() | |||||
| } | |||||
| logger.info('Special tokens {}'.format(self.special_tokens)) | |||||
| def bpe(self, token): | |||||
| if token in self.cache: | |||||
| return self.cache[token] | |||||
| word = tuple(token) | |||||
| pairs = get_pairs(word) | |||||
| if not pairs: | |||||
| return token | |||||
| while True: | |||||
| bigram = min( | |||||
| pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) | |||||
| if bigram not in self.bpe_ranks: | |||||
| break | |||||
| first, second = bigram | |||||
| new_word = [] | |||||
| i = 0 | |||||
| while i < len(word): | |||||
| try: | |||||
| j = word.index(first, i) | |||||
| new_word.extend(word[i:j]) | |||||
| i = j | |||||
| except Exception: | |||||
| new_word.extend(word[i:]) | |||||
| break | |||||
| if word[i] == first and i < len(word) - 1 and word[ | |||||
| i + 1] == second: | |||||
| new_word.append(first + second) | |||||
| i += 2 | |||||
| else: | |||||
| new_word.append(word[i]) | |||||
| i += 1 | |||||
| new_word = tuple(new_word) | |||||
| word = new_word | |||||
| if len(word) == 1: | |||||
| break | |||||
| else: | |||||
| pairs = get_pairs(word) | |||||
| word = ' '.join(word) | |||||
| self.cache[token] = word | |||||
| return word | |||||
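| # Merge walkthrough (toy ranks, hypothetical): with | |||||
| #     bpe_ranks = {('l', 'o'): 0, ('lo', 'w'): 1} | |||||
| # bpe('low') proceeds ('l','o','w') -> ('lo','w') -> ('low',), returns 'low', | |||||
| # and caches the result per token. | |||||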
| def tokenize(self, text): | |||||
| """ Tokenize a string. """ | |||||
| bpe_tokens = [] | |||||
| for token in re.findall(self.pat, text): | |||||
| token = ''.join(self.byte_encoder[ord(b)] for b in token | |||||
| if ord(b) in self.byte_encoder) | |||||
| if token == '': | |||||
| continue | |||||
| bpe_tokens.extend( | |||||
| bpe_token for bpe_token in self.bpe(token).split(' ')) | |||||
| return bpe_tokens | |||||
| def convert_tokens_to_ids(self, tokens): | |||||
| """ Converts a sequence of tokens into ids using the vocab. """ | |||||
| ids = [] | |||||
| # a single token (str) may be passed instead of a list of tokens | |||||
| is_single_str = isinstance(tokens, str) or ( | |||||
| sys.version_info[0] == 2 and isinstance(tokens, unicode)) | |||||
| if is_single_str: | |||||
| if tokens in self.special_tokens: | |||||
| return self.special_tokens[tokens] | |||||
| else: | |||||
| return self.encoder.get(tokens, 0) | |||||
| for token in tokens: | |||||
| if token in self.special_tokens: | |||||
| ids.append(self.special_tokens[token]) | |||||
| else: | |||||
| ids.append(self.encoder.get(token, 0)) | |||||
| if len(ids) > self.max_len: | |||||
| logger.warning( | |||||
| 'Token indices sequence length is longer than the specified maximum ' | |||||
| ' sequence length for this OpenAI GPT model ({} > {}). Running this' | |||||
| ' sequence through the model will result in indexing errors'. | |||||
| format(len(ids), self.max_len)) | |||||
| return ids | |||||
| def convert_ids_to_tokens(self, ids, skip_special_tokens=False): | |||||
| """Converts a sequence of ids in BPE tokens using the vocab.""" | |||||
| tokens = [] | |||||
| for i in ids: | |||||
| if i in self.special_tokens_decoder: | |||||
| if not skip_special_tokens: | |||||
| tokens.append(self.special_tokens_decoder[i]) | |||||
| else: | |||||
| tokens.append(self.decoder[i]) | |||||
| return tokens | |||||
| def encode(self, text): | |||||
| return self.convert_tokens_to_ids(self.tokenize(text)) | |||||
| def decode(self, tokens): | |||||
| text = ''.join([self.decoder[token] for token in tokens]) | |||||
| text = bytearray([self.byte_decoder[c] for c in text]).decode( | |||||
| 'utf-8', errors=self.errors) | |||||
| return text | |||||
| @@ -0,0 +1,73 @@ | |||||
| """ | |||||
| MetricsTracker class | |||||
| """ | |||||
| import math | |||||
| from collections import defaultdict | |||||
| class MetricsTracker(object): | |||||
| """ Tracking metrics. """ | |||||
| def __init__(self): | |||||
| self.metrics_val = defaultdict(float) # for one batch | |||||
| self.metrics_avg = defaultdict(float) # avg batches | |||||
| self.num_samples = 0 | |||||
| def update(self, metrics, num_samples): | |||||
| for key, val in metrics.items(): | |||||
| if val is not None: | |||||
| val = float(val) # [val] -> val | |||||
| self.metrics_val[key] = val | |||||
| avg_val = \ | |||||
| (self.metrics_avg.get(key, 0) * self.num_samples + val * num_samples) / \ | |||||
| (self.num_samples + num_samples) | |||||
| self.metrics_avg[key] = avg_val | |||||
| self.num_samples += num_samples | |||||
| def clear(self): | |||||
| self.metrics_val = defaultdict(float) | |||||
| self.metrics_avg = defaultdict(float) | |||||
| self.num_samples = 0 | |||||
| def items(self): | |||||
| return self.metrics_avg.items() | |||||
| def get(self, name): | |||||
| if self.num_samples == 0: | |||||
| raise ValueError('There is no data in Metrics.') | |||||
| return self.metrics_avg.get(name) | |||||
| def state_dict(self): | |||||
| return { | |||||
| 'metrics_val': self.metrics_val, | |||||
| 'metrics_avg': self.metrics_avg, | |||||
| 'num_samples': self.num_samples, | |||||
| } | |||||
| def load_state_dict(self, state_dict): | |||||
| self.metrics_val = state_dict['metrics_val'] | |||||
| self.metrics_avg = state_dict['metrics_avg'] | |||||
| self.num_samples = state_dict['num_samples'] | |||||
| def value(self): | |||||
| metric_strs = [] | |||||
| for key, val in self.metrics_val.items(): | |||||
| metric_str = f'{key.upper()}-{val:.3f}' | |||||
| metric_strs.append(metric_str) | |||||
| if 'token_nll' in self.metrics_val: | |||||
| metric_str = f"TOKEN_PPL-{math.exp(self.metrics_val['token_nll']):.3f}" | |||||
| metric_strs.append(metric_str) | |||||
| metric_strs = ' '.join(metric_strs) | |||||
| return metric_strs | |||||
| def summary(self): | |||||
| metric_strs = [] | |||||
| for key, val in self.metrics_avg.items(): | |||||
| metric_str = f'{key.upper()}-{val:.3f}' | |||||
| metric_strs.append(metric_str) | |||||
| if 'token_nll' in self.metrics_avg: | |||||
| metric_str = f"TOKEN_PPL-{math.exp(self.metrics_avg['token_nll']):.3f}" | |||||
| metric_strs.append(metric_str) | |||||
| metric_strs = ' '.join(metric_strs) | |||||
| return metric_strs | |||||
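| # Usage sketch: update() keeps a sample-weighted running average. | |||||
| #     mt = MetricsTracker() | |||||
| #     mt.update({'loss': 2.0}, num_samples=10) | |||||
| #     mt.update({'loss': 1.0}, num_samples=30) | |||||
| #     mt.get('loss')   # (2.0 * 10 + 1.0 * 30) / 40 == 1.25 | |||||
| #     mt.summary()     # 'LOSS-1.250' | |||||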
| @@ -0,0 +1,761 @@ | |||||
| """ | |||||
| Trainer class. | |||||
| """ | |||||
| import logging | |||||
| import os | |||||
| import sys | |||||
| import time | |||||
| from collections import OrderedDict | |||||
| import json | |||||
| import numpy as np | |||||
| import torch | |||||
| from tqdm import tqdm | |||||
| from transformers.optimization import AdamW, get_linear_schedule_with_warmup | |||||
| from .....utils.nlp.space import ontology | |||||
| from ..metrics.metrics_tracker import MetricsTracker | |||||
| def get_logger(log_path, name='default'): | |||||
| logger = logging.getLogger(name) | |||||
| logger.propagate = False | |||||
| logger.setLevel(logging.DEBUG) | |||||
| formatter = logging.Formatter('%(message)s') | |||||
| sh = logging.StreamHandler(sys.stdout) | |||||
| sh.setFormatter(formatter) | |||||
| logger.addHandler(sh) | |||||
| fh = logging.FileHandler(log_path, mode='w') | |||||
| fh.setFormatter(formatter) | |||||
| logger.addHandler(fh) | |||||
| return logger | |||||
| class Trainer(object): | |||||
| def __init__(self, | |||||
| model, | |||||
| to_tensor, | |||||
| config, | |||||
| logger=None, | |||||
| lr_scheduler=None, | |||||
| optimizer=None, | |||||
| reader=None, | |||||
| evaluator=None): | |||||
| self.to_tensor = to_tensor | |||||
| self.do_train = config.do_train | |||||
| self.do_infer = config.do_infer | |||||
| self.is_decreased_valid_metric = config.Trainer.valid_metric_name[ | |||||
| 0] == '-' | |||||
| self.valid_metric_name = config.Trainer.valid_metric_name[1:] | |||||
| self.num_epochs = config.Trainer.num_epochs | |||||
| # self.save_dir = config.Trainer.save_dir | |||||
| self.log_steps = config.Trainer.log_steps | |||||
| self.valid_steps = config.Trainer.valid_steps | |||||
| self.save_checkpoint = config.Trainer.save_checkpoint | |||||
| self.save_summary = config.Trainer.save_summary | |||||
| self.lr = config.Model.lr | |||||
| self.weight_decay = config.Model.weight_decay | |||||
| self.batch_size = config.Trainer.batch_size | |||||
| self.gradient_accumulation_steps = config.Model.gradient_accumulation_steps | |||||
| self.warmup_steps = config.Model.warmup_steps | |||||
| self.gpu = config.Trainer.gpu | |||||
| self.lr_scheduler = lr_scheduler | |||||
| self.optimizer = optimizer | |||||
| self.model = model | |||||
| self.func_model = self.model.module if self.gpu > 1 else self.model | |||||
| self.reader = reader | |||||
| self.evaluator = evaluator | |||||
| self.tokenizer = reader.tokenizer | |||||
| # if not os.path.exists(self.save_dir): | |||||
| # os.makedirs(self.save_dir) | |||||
| # self.logger = logger or get_logger(os.path.join(self.save_dir, "trainer.log"), "trainer") | |||||
| self.logger = logger or get_logger('trainer.log', 'trainer') | |||||
| self.batch_metrics_tracker = MetricsTracker() | |||||
| self.token_metrics_tracker = MetricsTracker() | |||||
| self.best_valid_metric = float( | |||||
| 'inf' if self.is_decreased_valid_metric else '-inf') | |||||
| self.epoch = 0 | |||||
| def decode_generated_bspn_resp(self, generated): | |||||
| """ | |||||
| decode generated | |||||
| return decoded ('bspn', 'resp') | |||||
| """ | |||||
| decoded = {} | |||||
| eos_r_id = self.reader.eos_r_id | |||||
| eos_b_id = self.reader.eos_b_id | |||||
| # eos_r may not exist if GPT-2 generates repetitive words. | |||||
| if eos_r_id in generated: | |||||
| eos_r_idx = generated.index(eos_r_id) | |||||
| else: | |||||
| eos_r_idx = len(generated) - 1 | |||||
| # self.logger.info('eos_r not in generated: ' + self.tokenizer.decode(generated)) | |||||
| # predicted bspn, resp | |||||
| eos_b_idx = generated.index(eos_b_id) | |||||
| decoded['bspn'] = generated[:eos_b_idx + 1] | |||||
| decoded['resp'] = generated[eos_b_idx + 1:eos_r_idx + 1] | |||||
| return decoded | |||||
| def decode_generated_act_resp(self, generated): | |||||
| """ | |||||
| decode generated | |||||
| return decoded['resp'] ('bspn', 'aspn') | |||||
| """ | |||||
| decoded = {} | |||||
| eos_a_id = self.reader.eos_a_id | |||||
| eos_r_id = self.reader.eos_r_id | |||||
| # eos_b_id = self.reader.eos_b_id | |||||
| # eos_r may not exist if GPT-2 generates repetitive words. | |||||
| if eos_r_id in generated: | |||||
| eos_r_idx = generated.index(eos_r_id) | |||||
| else: | |||||
| eos_r_idx = len(generated) - 1 | |||||
| msg = 'eos_r not in generated: ' + self.tokenizer.decode(generated) | |||||
| self.logger.info(msg) | |||||
| if self.reader.use_true_curr_aspn: # only predict resp | |||||
| decoded['resp'] = generated[:eos_r_idx + 1] | |||||
| else: # predicted aspn, resp | |||||
| eos_a_idx = generated.index(eos_a_id) | |||||
| decoded['aspn'] = generated[:eos_a_idx + 1] | |||||
| decoded['resp'] = generated[eos_a_idx + 1:eos_r_idx + 1] | |||||
| return decoded | |||||
| def decode_generated_bspn(self, generated): | |||||
| eos_b_id = self.reader.eos_b_id | |||||
| if eos_b_id in generated: | |||||
| eos_b_idx = generated.index(eos_b_id) | |||||
| else: | |||||
| eos_b_idx = len(generated) - 1 | |||||
| return generated[:eos_b_idx + 1] | |||||
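All three decode helpers work the same way: they slice the flat list of generated token ids at the special end-of-span markers, falling back to the end of the sequence when eos_r was never produced (e.g. when the decoder loops on repetitive tokens). A toy sketch of the slicing, with ids 90 and 92 as hypothetical stand-ins for eos_b_id and eos_r_id:

    # Hypothetical special-token ids; the real ones come from the reader.
    eos_b_id, eos_r_id = 90, 92
    generated = [1, 2, 3, 90, 4, 5, 91, 6, 7, 92]

    eos_b_idx = generated.index(eos_b_id)
    eos_r_idx = generated.index(eos_r_id) if eos_r_id in generated else len(generated) - 1

    bspn = generated[:eos_b_idx + 1]               # [1, 2, 3, 90]
    resp = generated[eos_b_idx + 1:eos_r_idx + 1]  # [4, 5, 91, 6, 7, 92]
    assert bspn[-1] == eos_b_id and resp[-1] == eos_r_id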
| def set_optimizers(self): | |||||
| """ | |||||
| Setup the optimizer and the learning rate scheduler. | |||||
| from transformers.Trainer | |||||
| parameters from cfg: lr (1e-3); warmup_steps | |||||
| """ | |||||
| # Prepare optimizer and schedule (linear warmup and decay) | |||||
| no_decay = ['bias', 'norm.weight'] | |||||
| optimizer_grouped_parameters = [ | |||||
| { | |||||
| 'params': [ | |||||
| p for n, p in self.model.named_parameters() | |||||
| if not any(nd in n for nd in no_decay) | |||||
| ], | |||||
| 'weight_decay': | |||||
| self.weight_decay, | |||||
| }, | |||||
| { | |||||
| 'params': [ | |||||
| p for n, p in self.model.named_parameters() | |||||
| if any(nd in n for nd in no_decay) | |||||
| ], | |||||
| 'weight_decay': | |||||
| 0.0, | |||||
| }, | |||||
| ] | |||||
| optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr) | |||||
| num_training_steps = \ | |||||
| self.reader.set_stats['train']['num_training_steps_per_epoch'] \ | |||||
| * self.num_epochs \ | |||||
| // self.gradient_accumulation_steps | |||||
| num_warmup_steps = self.warmup_steps if self.warmup_steps >= 0 else int( | |||||
| num_training_steps * 0.1) | |||||
| lr_scheduler = get_linear_schedule_with_warmup( | |||||
| optimizer, | |||||
| num_warmup_steps=num_warmup_steps, | |||||
| num_training_steps=num_training_steps) | |||||
| self.optimizer = optimizer | |||||
| self.lr_scheduler = lr_scheduler | |||||
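Two details of set_optimizers are easy to miss: parameters whose names contain 'bias' or 'norm.weight' are placed in a group with zero weight decay, and a negative warmup_steps in the config falls back to 10% of the total optimization steps. A small numeric check of the fallback, under assumed config values:

    # Assumed numbers: 1000 steps/epoch, 10 epochs, gradient accumulation of 4.
    num_training_steps = 1000 * 10 // 4  # 2500 optimization steps
    warmup_steps = -1                    # negative means "auto" here
    num_warmup_steps = warmup_steps if warmup_steps >= 0 else int(num_training_steps * 0.1)
    assert num_warmup_steps == 250       # 10% of the total steps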
| def train(self, train_data, dev_data): | |||||
| # log info | |||||
| set_stats = self.reader.set_stats['train'] | |||||
| self.logger.info('***** Running training *****') | |||||
| self.logger.info( | |||||
| ' Num Training steps (one turn in a batch of dialogs) per epoch = %d', | |||||
| set_stats['num_training_steps_per_epoch']) | |||||
| self.logger.info(' Num Turns = %d', set_stats['num_turns']) | |||||
| self.logger.info(' Num Dialogs = %d', set_stats['num_dials']) | |||||
| self.logger.info(' Num Epochs = %d', self.num_epochs) | |||||
| self.logger.info(' Batch size = %d', self.batch_size) | |||||
| self.logger.info(' Gradient Accumulation steps = %d', | |||||
| self.gradient_accumulation_steps) | |||||
| steps = set_stats[ | |||||
| 'num_training_steps_per_epoch'] * self.num_epochs // self.gradient_accumulation_steps | |||||
| msg = ' Total optimization steps = %d' % steps | |||||
| self.logger.info(msg) | |||||
| # begin training | |||||
| num_epochs = self.num_epochs - self.epoch | |||||
| for epoch in range(num_epochs): | |||||
| self.train_epoch(train_data=train_data, dev_data=dev_data) | |||||
| def train_epoch(self, train_data, dev_data): | |||||
| """ | |||||
| Train an epoch. | |||||
| """ | |||||
| raise NotImplementedError | |||||
| def infer(self, data_type): | |||||
| """ | |||||
| Inference interface. | |||||
| """ | |||||
| raise NotImplementedError | |||||
| def forward(self, turn, old_pv_turn): | |||||
| """ | |||||
| one turn inference | |||||
| """ | |||||
| raise NotImplementedError | |||||
| def save(self, is_best=False): | |||||
| """ save """ | |||||
| train_state = { | |||||
| 'epoch': self.epoch, | |||||
| 'best_valid_metric': self.best_valid_metric, | |||||
| 'optimizer': self.optimizer.state_dict() | |||||
| } | |||||
| if self.lr_scheduler is not None: | |||||
| train_state['lr_scheduler'] = self.lr_scheduler.state_dict() | |||||
| # Save checkpoint | |||||
| if self.save_checkpoint: | |||||
| model_file = os.path.join(self.save_dir, | |||||
| f'state_epoch_{self.epoch}.model') | |||||
| torch.save(self.model.state_dict(), model_file) | |||||
| self.logger.info(f"Saved model state to '{model_file}'") | |||||
| train_file = os.path.join(self.save_dir, | |||||
| f'state_epoch_{self.epoch}.train') | |||||
| torch.save(train_state, train_file) | |||||
| self.logger.info(f"Saved train state to '{train_file}'") | |||||
| # Save current best model | |||||
| if is_best: | |||||
| best_model_file = os.path.join(self.save_dir, 'best.model') | |||||
| torch.save(self.model.state_dict(), best_model_file) | |||||
| best_train_file = os.path.join(self.save_dir, 'best.train') | |||||
| torch.save(train_state, best_train_file) | |||||
| self.logger.info( | |||||
| f"Saved best model state to '{best_model_file}' with new best valid metric " | |||||
| f'{self.valid_metric_name.upper()}={self.best_valid_metric:.3f}' | |||||
| ) | |||||
| def load(self): | |||||
| """ load """ | |||||
| def _load_model_state(): | |||||
| model_state_dict = torch.load( | |||||
| f'{self.func_model.init_checkpoint}', | |||||
| map_location=lambda storage, loc: storage) | |||||
| if 'module.' in list(model_state_dict.keys())[0]: | |||||
| new_model_state_dict = OrderedDict() | |||||
| for k, v in model_state_dict.items(): | |||||
| assert k[:7] == 'module.' | |||||
| new_model_state_dict[k[7:]] = v | |||||
| model_state_dict = new_model_state_dict | |||||
| new_model_state_dict = OrderedDict() | |||||
| parameters = { | |||||
| name: param | |||||
| for name, param in self.func_model.named_parameters() | |||||
| } | |||||
| for name, param in model_state_dict.items(): | |||||
| if name in parameters: | |||||
| if param.shape != parameters[name].shape: | |||||
| assert hasattr(param, 'numpy') | |||||
| arr = param.numpy() | |||||
| z = np.random.normal( | |||||
| scale=self.func_model.initializer_range, | |||||
| size=parameters[name].shape).astype('float32') | |||||
| if name == 'embedder.token_embedding.weight': | |||||
| z[-param.shape[0]:] = arr | |||||
| print(f'part of parameter({name}) is randomly normal-initialized') | |||||
| else: | |||||
| if z.shape[0] < param.shape[0]: | |||||
| z = arr[:z.shape[0]] | |||||
| print(f'part of parameter({name}) is dropped') | |||||
| else: | |||||
| z[:param.shape[0]] = arr | |||||
| print(f'part of parameter({name}) is randomly normal-initialized') | |||||
| dtype, device = param.dtype, param.device | |||||
| z = torch.tensor(z, dtype=dtype, device=device) | |||||
| new_model_state_dict[name] = z | |||||
| else: | |||||
| new_model_state_dict[name] = param | |||||
| else: | |||||
| print(f'parameter({name}) is dropped') | |||||
| model_state_dict = new_model_state_dict | |||||
| for name in parameters: | |||||
| if name not in model_state_dict: | |||||
| if parameters[name].requires_grad: | |||||
| print(f'parameter({name}) is randomly normal-initialized') | |||||
| z = np.random.normal( | |||||
| scale=self.func_model.initializer_range, | |||||
| size=parameters[name].shape).astype('float32') | |||||
| dtype, device = parameters[name].dtype, parameters[ | |||||
| name].device | |||||
| model_state_dict[name] = torch.tensor( | |||||
| z, dtype=dtype, device=device) | |||||
| else: | |||||
| model_state_dict[name] = parameters[name] | |||||
| self.func_model.load_state_dict(model_state_dict) | |||||
| self.logger.info( | |||||
| f"Loaded model state from '{self.func_model.init_checkpoint}'" | |||||
| ) | |||||
| def _load_train_state(): | |||||
| train_file = f'{self.func_model.init_checkpoint}.train' | |||||
| if os.path.exists(train_file): | |||||
| train_state_dict = torch.load( | |||||
| train_file, map_location=lambda storage, loc: storage) | |||||
| self.epoch = train_state_dict['epoch'] | |||||
| self.best_valid_metric = train_state_dict['best_valid_metric'] | |||||
| if self.optimizer is not None and 'optimizer' in train_state_dict: | |||||
| self.optimizer.load_state_dict( | |||||
| train_state_dict['optimizer']) | |||||
| if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict: | |||||
| self.lr_scheduler.load_state_dict( | |||||
| train_state_dict['lr_scheduler']) | |||||
| self.logger.info( | |||||
| f"Loaded train state from '{train_file}' with (epoch-{self.epoch} " | |||||
| f'best_valid_metric={self.best_valid_metric:.3f})') | |||||
| else: | |||||
| self.logger.info('Loaded no train state') | |||||
| if self.func_model.init_checkpoint is None: | |||||
| self.logger.info('Loaded no model !!!') | |||||
| return | |||||
| if self.do_train: | |||||
| _load_model_state() | |||||
| return | |||||
| if self.do_infer: | |||||
| _load_model_state() | |||||
| _load_train_state() | |||||
| class MultiWOZTrainer(Trainer): | |||||
| def __init__(self, | |||||
| model, | |||||
| to_tensor, | |||||
| config, | |||||
| logger=None, | |||||
| lr_scheduler=None, | |||||
| optimizer=None, | |||||
| reader=None, | |||||
| evaluator=None): | |||||
| super(MultiWOZTrainer, | |||||
| self).__init__(model, to_tensor, config, logger, lr_scheduler, | |||||
| optimizer, reader, evaluator) | |||||
| def train_epoch(self, train_data, dev_data): | |||||
| """ | |||||
| Train an epoch. | |||||
| """ | |||||
| times = [] | |||||
| epoch_step = 0 | |||||
| global_step = 0 | |||||
| tr_batch_loss = 0.0 | |||||
| tr_token_loss = 0.0 | |||||
| self.epoch += 1 | |||||
| self.batch_metrics_tracker.clear() | |||||
| self.token_metrics_tracker.clear() | |||||
| num_training_steps = \ | |||||
| self.reader.set_stats['train']['num_training_steps_per_epoch'] // \ | |||||
| self.gradient_accumulation_steps # similar to the original num_batches | |||||
| self.model.zero_grad() | |||||
| data_iterator = self.reader.get_data_iterator(all_batches=train_data) | |||||
| for batch_idx, dial_batch in enumerate(data_iterator): | |||||
| pv_batch = [] | |||||
| for turn_num, turn_batch in enumerate(dial_batch): | |||||
| first_turn = (turn_num == 0) | |||||
| samples, pv_batch = self.reader.convert_batch_turn( | |||||
| turn_batch, pv_batch, first_turn) | |||||
| batch, batch_size = self.reader.collate_fn_multi_turn( | |||||
| samples=samples) | |||||
| batch = type(batch)( | |||||
| map(lambda kv: (kv[0], self.to_tensor(kv[1])), | |||||
| batch.items())) | |||||
| # Do a training iteration | |||||
| start_time = time.time() | |||||
| metrics = self.model(batch, is_training=True) | |||||
| if self.gpu > 1: | |||||
| for metric in metrics: | |||||
| if metric is not None: | |||||
| assert len(metric) == self.gpu | |||||
| nll, token_nll, token_num = metrics | |||||
| metrics = {} | |||||
| token_num = torch.sum(token_num) | |||||
| token_nll = \ | |||||
| torch.sum(nll) * (batch_size / self.gpu) / \ | |||||
| token_num | |||||
| nll = torch.mean(nll) | |||||
| metrics['token_num'] = token_num | |||||
| metrics['token_nll'] = token_nll | |||||
| metrics['nll'] = nll | |||||
| loss = token_nll if self.func_model.token_loss else nll | |||||
| metrics['loss'] = loss | |||||
| else: | |||||
| loss = metrics['loss'] | |||||
| self.func_model._optimize( | |||||
| loss, do_update=False, optimizer=self.optimizer) | |||||
| metrics = { | |||||
| k: v.cpu().detach().numpy() | |||||
| if isinstance(v, torch.Tensor) else v | |||||
| for k, v in metrics.items() | |||||
| } | |||||
| token_num = metrics.pop('token_num', None) | |||||
| # bow_num = metrics.pop("bow_num", None) | |||||
| elapsed = time.time() - start_time | |||||
| times.append(elapsed) | |||||
| epoch_step += 1 | |||||
| tr_batch_loss += metrics['nll'] | |||||
| tr_token_loss += metrics['token_nll'] | |||||
| batch_metrics = { | |||||
| k: v | |||||
| for k, v in metrics.items() if 'token' not in k | |||||
| } | |||||
| token_metrics = { | |||||
| k: v | |||||
| for k, v in metrics.items() if 'token' in k | |||||
| } | |||||
| self.batch_metrics_tracker.update(batch_metrics, batch_size) | |||||
| self.token_metrics_tracker.update(token_metrics, token_num) | |||||
| if (epoch_step % self.gradient_accumulation_steps == 0) or \ | |||||
| (epoch_step == self.reader.set_stats['train']['num_training_steps_per_epoch']): | |||||
| self.optimizer.step() | |||||
| self.lr_scheduler.step() | |||||
| self.optimizer.zero_grad() | |||||
| global_step += 1 | |||||
| if self.log_steps > 0 and global_step % self.log_steps == 0: | |||||
| batch_metrics_message = self.batch_metrics_tracker.value( | |||||
| ) | |||||
| token_metrics_message = self.token_metrics_tracker.value( | |||||
| ) | |||||
| message_prefix = f'[Train][{self.epoch}][{global_step}/{num_training_steps}]' | |||||
| avg_time = f'AVG_Time-{sum(times[-self.log_steps:]) / self.log_steps:.3f}' | |||||
| message = ' '.join([ | |||||
| message_prefix, batch_metrics_message, | |||||
| token_metrics_message, avg_time | |||||
| ]) | |||||
| self.logger.info(message) | |||||
| self.logger.info('-' * 150) | |||||
| avg_batch_loss = tr_batch_loss / epoch_step | |||||
| avg_token_loss = tr_token_loss / epoch_step | |||||
| batch_metrics_message = self.batch_metrics_tracker.summary() | |||||
| token_metrics_message = self.token_metrics_tracker.summary() | |||||
| message_prefix = f'[Valid][{self.epoch}]' | |||||
| message = ' '.join([ | |||||
| message_prefix, batch_metrics_message, token_metrics_message, | |||||
| str(avg_batch_loss), | |||||
| str(avg_token_loss) | |||||
| ]) | |||||
| self.logger.info(message) | |||||
| cur_valid_metric = self.batch_metrics_tracker.get( | |||||
| self.valid_metric_name) | |||||
| if self.is_decreased_valid_metric: | |||||
| is_best = cur_valid_metric < self.best_valid_metric | |||||
| else: | |||||
| is_best = cur_valid_metric > self.best_valid_metric | |||||
| if is_best: | |||||
| self.best_valid_metric = cur_valid_metric | |||||
| self.save(is_best) | |||||
| self.logger.info('-' * 150) | |||||
| return | |||||
| def infer(self, data_type='test'): | |||||
| """ | |||||
| Inference interface. | |||||
| """ | |||||
| self.logger.info('Generation starts ...') | |||||
| infer_save_file = os.path.join(self.save_dir, | |||||
| f'infer_{self.epoch}.result.json') | |||||
| infer_samples_save_file = os.path.join( | |||||
| self.save_dir, f'infer_samples_{self.epoch}.result.json') | |||||
| # Inference | |||||
| result_collection = {} | |||||
| begin_time = time.time() | |||||
| eval_data = self.reader.get_eval_data(data_type) | |||||
| set_stats = self.reader.set_stats[data_type] | |||||
| self.logger.info('***** Running Evaluation *****') | |||||
| self.logger.info(' Num Turns = %d', set_stats['num_turns']) | |||||
| with torch.no_grad(): | |||||
| pbar = tqdm(eval_data) | |||||
| for dial_idx, dialog in enumerate(pbar): | |||||
| pv_turn = {} | |||||
| for turn_idx, turn in enumerate(dialog): | |||||
| first_turn = (turn_idx == 0) | |||||
| inputs, prompt_id = self.reader.convert_turn_eval( | |||||
| turn, pv_turn, first_turn) | |||||
| batch, batch_size = self.reader.collate_fn_multi_turn( | |||||
| samples=[inputs]) | |||||
| batch = type(batch)( | |||||
| map(lambda kv: (kv[0], self.to_tensor(kv[1])), | |||||
| batch.items())) | |||||
| if self.reader.use_true_curr_bspn: # generate act, response | |||||
| max_len = 60 | |||||
| if not self.reader.use_true_curr_aspn: | |||||
| max_len = 80 | |||||
| outputs = self.func_model.infer( | |||||
| inputs=batch, | |||||
| start_id=prompt_id, | |||||
| eos_id=self.reader.eos_r_id, | |||||
| max_gen_len=max_len) | |||||
| # resp_gen, need to trim previous context | |||||
| generated = outputs[0].cpu().numpy().tolist() | |||||
| try: | |||||
| decoded = self.decode_generated_act_resp(generated) | |||||
| except ValueError as exception: | |||||
| self.logger.info(str(exception)) | |||||
| self.logger.info(self.tokenizer.decode(generated)) | |||||
| decoded = {'resp': [], 'bspn': [], 'aspn': []} | |||||
| else: # predict bspn, access db, then generate act and resp | |||||
| outputs = self.func_model.infer( | |||||
| inputs=batch, | |||||
| start_id=prompt_id, | |||||
| eos_id=self.reader.eos_b_id, | |||||
| max_gen_len=60) | |||||
| generated_bs = outputs[0].cpu().numpy().tolist() | |||||
| bspn_gen = self.decode_generated_bspn(generated_bs) | |||||
| # check DB result | |||||
| if self.reader.use_true_db_pointer: # use the ground-truth db result | |||||
| db = turn['db'] | |||||
| else: | |||||
| db_result = self.reader.bspan_to_DBpointer( | |||||
| self.tokenizer.decode(bspn_gen), | |||||
| turn['turn_domain']) | |||||
| assert len(turn['db']) == 4 | |||||
| book_result = turn['db'][2] | |||||
| assert isinstance(db_result, str) | |||||
| db = \ | |||||
| [self.reader.sos_db_id] + \ | |||||
| self.tokenizer.convert_tokens_to_ids([db_result]) + \ | |||||
| [book_result] + \ | |||||
| [self.reader.eos_db_id] | |||||
| prompt_id = self.reader.sos_a_id | |||||
| prev_input = torch.tensor(bspn_gen + db) | |||||
| if self.func_model.use_gpu: | |||||
| prev_input = prev_input.cuda() | |||||
| outputs_db = self.func_model.infer( | |||||
| inputs=batch, | |||||
| start_id=prompt_id, | |||||
| eos_id=self.reader.eos_r_id, | |||||
| max_gen_len=80, | |||||
| prev_input=prev_input) | |||||
| generated_ar = outputs_db[0].cpu().numpy().tolist() | |||||
| try: | |||||
| decoded = self.decode_generated_act_resp( | |||||
| generated_ar) | |||||
| decoded['bspn'] = bspn_gen | |||||
| except ValueError as exception: | |||||
| self.logger.info(str(exception)) | |||||
| self.logger.info( | |||||
| self.tokenizer.decode(generated_ar)) | |||||
| decoded = {'resp': [], 'bspn': [], 'aspn': []} | |||||
| turn['resp_gen'] = decoded['resp'] | |||||
| turn['bspn_gen'] = turn[ | |||||
| 'bspn'] if self.reader.use_true_curr_bspn else decoded[ | |||||
| 'bspn'] | |||||
| turn['aspn_gen'] = turn[ | |||||
| 'aspn'] if self.reader.use_true_curr_aspn else decoded[ | |||||
| 'aspn'] | |||||
| turn['dspn_gen'] = turn['dspn'] | |||||
| pv_turn['labels'] = inputs[ | |||||
| 'labels'] # all true previous context | |||||
| pv_turn['resp'] = turn[ | |||||
| 'resp'] if self.reader.use_true_prev_resp else decoded[ | |||||
| 'resp'] | |||||
| if not self.reader.use_true_curr_bspn: | |||||
| pv_turn['bspn'] = turn[ | |||||
| 'bspn'] if self.reader.use_true_prev_bspn else decoded[ | |||||
| 'bspn'] | |||||
| pv_turn['db'] = turn[ | |||||
| 'db'] if self.reader.use_true_prev_bspn else db | |||||
| pv_turn['aspn'] = turn[ | |||||
| 'aspn'] if self.reader.use_true_prev_aspn else decoded[ | |||||
| 'aspn'] | |||||
| tmp_dialog_result = self.reader.inverse_transpose_turn(dialog) | |||||
| result_collection.update(tmp_dialog_result) | |||||
| # compute tmp scores | |||||
| results, _ = self.reader.wrap_result_lm(tmp_dialog_result) | |||||
| bleu, success, match = self.evaluator.validation_metric( | |||||
| results) | |||||
| score = 0.5 * (success + match) + bleu | |||||
| pbar.set_description( | |||||
| 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' % | |||||
| (match, success, bleu, score)) | |||||
| # compute scores | |||||
| results, _ = self.reader.wrap_result_lm(result_collection) | |||||
| bleu, success, match = self.evaluator.validation_metric(results) | |||||
| score = 0.5 * (success + match) + bleu | |||||
| # log results | |||||
| metrics_message = 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' %\ | |||||
| (match, success, bleu, score) | |||||
| message_prefix = f'[Infer][{self.epoch}]' | |||||
| time_cost = f'TIME-{time.time() - begin_time:.3f}' | |||||
| message = ' '.join([message_prefix, metrics_message, time_cost]) | |||||
| self.logger.info(message) | |||||
| # save results | |||||
| eval_results = { | |||||
| 'bleu': bleu, | |||||
| 'success': success, | |||||
| 'match': match, | |||||
| 'score': score, | |||||
| 'result': message | |||||
| } | |||||
| with open(infer_save_file, 'w') as fp: | |||||
| json.dump(eval_results, fp, indent=2) | |||||
| self.logger.info(f'Saved inference results to {infer_save_file}') | |||||
| with open(infer_samples_save_file, 'w') as fp: | |||||
| for sample in results: | |||||
| line = json.dumps(sample) | |||||
| fp.write(line) | |||||
| fp.write('\n') | |||||
| self.logger.info( | |||||
| f'Saved inference samples to {infer_samples_save_file}') | |||||
| return | |||||
| def _get_turn_domain(self, old_pv_turn, bspn_gen_ids, first_turn): | |||||
| def _get_slots(constraint): | |||||
| domain_name = '' | |||||
| slots = {} | |||||
| for item in constraint: | |||||
| if item in ontology.placeholder_tokens: | |||||
| continue | |||||
| if item in ontology.all_domains_with_bracket: | |||||
| domain_name = item | |||||
| slots[domain_name] = set() | |||||
| else: | |||||
| assert domain_name in ontology.all_domains_with_bracket | |||||
| slots[domain_name].add(item) | |||||
| return slots | |||||
| turn_domain = [] | |||||
| if first_turn and len(bspn_gen_ids) == 0: | |||||
| turn_domain = ['[general]'] | |||||
| return turn_domain | |||||
| bspn_token = self.tokenizer.convert_ids_to_tokens(bspn_gen_ids) | |||||
| turn_slots = _get_slots(bspn_token) | |||||
| if first_turn: | |||||
| return list(turn_slots.keys()) | |||||
| assert 'bspn' in old_pv_turn | |||||
| pv_bspn_token = self.tokenizer.convert_ids_to_tokens( | |||||
| old_pv_turn['bspn']) | |||||
| pv_turn_slots = _get_slots(pv_bspn_token) | |||||
| for domain, value in turn_slots.items(): | |||||
| pv_value = pv_turn_slots[ | |||||
| domain] if domain in pv_turn_slots else set() | |||||
| if len(value - pv_value) > 0 or len(pv_value - value) > 0: | |||||
| turn_domain.append(domain) | |||||
| if len(turn_domain) == 0: | |||||
| turn_domain = list(turn_slots.keys()) | |||||
| return turn_domain | |||||
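Put differently, a domain counts as active for the current turn exactly when its slot set differs from the previous turn's in either direction; if no domain changed, all domains mentioned in the current belief state are returned. A toy sketch of the set comparison (the bracketed names and slot tokens are illustrative):

    # Illustrative slot sets extracted from two consecutive belief states.
    pv_turn_slots = {'[hotel]': {'area=east'}}
    turn_slots = {'[hotel]': {'area=east', 'stars=4'}, '[taxi]': {'leave=09:00'}}

    turn_domain = [
        d for d, v in turn_slots.items()
        if (v - pv_turn_slots.get(d, set())) or (pv_turn_slots.get(d, set()) - v)
    ]
    assert turn_domain == ['[hotel]', '[taxi]']  # both differ from the previous turn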
| def forward(self, turn, old_pv_turn): | |||||
| with torch.no_grad(): | |||||
| first_turn = len(old_pv_turn) == 0 | |||||
| inputs, prompt_id = self.reader.convert_turn_eval( | |||||
| turn, old_pv_turn, first_turn) | |||||
| batch, batch_size = self.reader.collate_fn_multi_turn( | |||||
| samples=[inputs]) | |||||
| batch = type(batch)( | |||||
| map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items())) | |||||
| pv_turn = {} | |||||
| outputs = self.func_model.infer( | |||||
| inputs=batch, | |||||
| start_id=prompt_id, | |||||
| eos_id=self.reader.eos_b_id, | |||||
| max_gen_len=60) | |||||
| generated_bs = outputs[0].cpu().numpy().tolist() | |||||
| bspn_gen = self.decode_generated_bspn(generated_bs) | |||||
| turn_domain = self._get_turn_domain(old_pv_turn, bspn_gen, | |||||
| first_turn) | |||||
| db_result = self.reader.bspan_to_DBpointer( | |||||
| self.tokenizer.decode(bspn_gen), turn_domain) | |||||
| assert isinstance(db_result, str) | |||||
| db = \ | |||||
| [self.reader.sos_db_id] + \ | |||||
| self.tokenizer.convert_tokens_to_ids([db_result]) + \ | |||||
| [self.reader.eos_db_id] | |||||
| prompt_id = self.reader.sos_a_id | |||||
| prev_input = torch.tensor(bspn_gen + db) | |||||
| if self.func_model.use_gpu: | |||||
| prev_input = prev_input.cuda() | |||||
| outputs_db = self.func_model.infer( | |||||
| inputs=batch, | |||||
| start_id=prompt_id, | |||||
| eos_id=self.reader.eos_r_id, | |||||
| max_gen_len=80, | |||||
| prev_input=prev_input) | |||||
| generated_ar = outputs_db[0].cpu().numpy().tolist() | |||||
| decoded = self.decode_generated_act_resp(generated_ar) | |||||
| decoded['bspn'] = bspn_gen | |||||
| pv_turn['labels'] = inputs['labels'] | |||||
| pv_turn['resp'] = decoded['resp'] | |||||
| pv_turn['bspn'] = decoded['bspn'] | |||||
| pv_turn['db'] = db | |||||
| pv_turn['aspn'] = decoded['aspn'] | |||||
| return pv_turn | |||||
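Because forward returns exactly the context dict it expects back on the next call, multi-turn inference reduces to threading pv_turn through the dialog. A hedged driver sketch; the trainer and dialog objects are assumed to be constructed by the surrounding dialog-modeling pipeline:

    # `trainer` is a loaded MultiWOZTrainer; `dialog` is a list of turn dicts.
    pv_turn = {}
    for turn in dialog:
        pv_turn = trainer.forward(turn, pv_turn)  # feeds bspn/resp/db/aspn back in
        print(trainer.tokenizer.decode(pv_turn['resp']))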
| @@ -0,0 +1,821 @@ | |||||
| """ | |||||
| Trainer class. | |||||
| """ | |||||
| import logging | |||||
| import os | |||||
| import sys | |||||
| import time | |||||
| from collections import OrderedDict | |||||
| import json | |||||
| import numpy as np | |||||
| import torch | |||||
| from tqdm import tqdm | |||||
| from transformers.optimization import AdamW, get_linear_schedule_with_warmup | |||||
| from ..metrics.metrics_tracker import MetricsTracker | |||||
| def get_logger(log_path, name='default'): | |||||
| logger = logging.getLogger(name) | |||||
| logger.propagate = False | |||||
| logger.setLevel(logging.DEBUG) | |||||
| formatter = logging.Formatter('%(message)s') | |||||
| sh = logging.StreamHandler(sys.stdout) | |||||
| sh.setFormatter(formatter) | |||||
| logger.addHandler(sh) | |||||
| fh = logging.FileHandler(log_path, mode='w') | |||||
| fh.setFormatter(formatter) | |||||
| logger.addHandler(fh) | |||||
| return logger | |||||
| class Trainer(object): | |||||
| def __init__(self, | |||||
| model, | |||||
| to_tensor, | |||||
| config, | |||||
| reader=None, | |||||
| logger=None, | |||||
| lr_scheduler=None, | |||||
| optimizer=None): | |||||
| self.model = model | |||||
| self.to_tensor = to_tensor | |||||
| self.do_train = config.do_train | |||||
| self.do_infer = config.do_infer | |||||
| self.is_decreased_valid_metric = config.Trainer.valid_metric_name[ | |||||
| 0] == '-' | |||||
| self.valid_metric_name = config.Trainer.valid_metric_name[1:] | |||||
| self.num_epochs = config.Trainer.num_epochs | |||||
| self.save_dir = config.Trainer.save_dir | |||||
| self.log_steps = config.Trainer.log_steps | |||||
| self.valid_steps = config.Trainer.valid_steps | |||||
| self.save_checkpoint = config.Trainer.save_checkpoint | |||||
| self.save_summary = config.Trainer.save_summary | |||||
| self.learning_method = config.Dataset.learning_method | |||||
| self.weight_decay = config.Model.weight_decay | |||||
| self.warmup_steps = config.Model.warmup_steps | |||||
| self.batch_size_label = config.Trainer.batch_size_label | |||||
| self.batch_size_nolabel = config.Trainer.batch_size_nolabel | |||||
| self.gpu = config.Trainer.gpu | |||||
| self.lr = config.Model.lr | |||||
| self.model = model | |||||
| self.func_model = self.model.module if self.gpu > 1 else self.model | |||||
| self.reader = reader | |||||
| self.tokenizer = reader.tokenizer | |||||
| self.lr_scheduler = lr_scheduler | |||||
| self.optimizer = optimizer | |||||
| # if not os.path.exists(self.save_dir): | |||||
| # os.makedirs(self.save_dir) | |||||
| # self.logger = logger or get_logger(os.path.join(self.save_dir, "trainer.log"), "trainer") | |||||
| self.logger = logger or get_logger('trainer.log', 'trainer') | |||||
| self.batch_metrics_tracker_label = MetricsTracker() | |||||
| self.token_metrics_tracker_label = MetricsTracker() | |||||
| self.batch_metrics_tracker_nolabel = MetricsTracker() | |||||
| self.token_metrics_tracker_nolabel = MetricsTracker() | |||||
| self.best_valid_metric = float( | |||||
| 'inf' if self.is_decreased_valid_metric else '-inf') | |||||
| self.epoch = 0 | |||||
| self.batch_num = 0 | |||||
| def set_optimizers(self, num_training_steps_per_epoch): | |||||
| """ | |||||
| Setup the optimizer and the learning rate scheduler. | |||||
| from transformers.Trainer | |||||
| parameters from cfg: lr (1e-3); warmup_steps | |||||
| """ | |||||
| # Prepare optimizer and schedule (linear warmup and decay) | |||||
| no_decay = ['bias', 'norm.weight'] | |||||
| optimizer_grouped_parameters = [ | |||||
| { | |||||
| 'params': [ | |||||
| p for n, p in self.model.named_parameters() | |||||
| if not any(nd in n for nd in no_decay) | |||||
| ], | |||||
| 'weight_decay': | |||||
| self.weight_decay, | |||||
| }, | |||||
| { | |||||
| 'params': [ | |||||
| p for n, p in self.model.named_parameters() | |||||
| if any(nd in n for nd in no_decay) | |||||
| ], | |||||
| 'weight_decay': | |||||
| 0.0, | |||||
| }, | |||||
| ] | |||||
| optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr) | |||||
| num_training_steps = num_training_steps_per_epoch * self.num_epochs | |||||
| num_warmup_steps = self.warmup_steps if self.warmup_steps >= 0 else int( | |||||
| num_training_steps * 0.1) | |||||
| lr_scheduler = get_linear_schedule_with_warmup( | |||||
| optimizer, | |||||
| num_warmup_steps=num_warmup_steps, | |||||
| num_training_steps=num_training_steps) | |||||
| # reset optimizer and lr_scheduler | |||||
| self.optimizer = optimizer | |||||
| self.lr_scheduler = lr_scheduler | |||||
| # log info | |||||
| self.logger.info( | |||||
| f'***** Running training: {self.learning_method} *****') | |||||
| self.logger.info(' Num Epochs = %d', self.num_epochs) | |||||
| self.logger.info( | |||||
| ' Num Training steps (one turn in a batch of dialogs) per epoch = %d', | |||||
| num_training_steps_per_epoch) | |||||
| self.logger.info(' Batch size for labeled data = %d', | |||||
| self.batch_size_label) | |||||
| self.logger.info(' Batch size for unlabeled data = %d', | |||||
| self.batch_size_nolabel) | |||||
| self.logger.info(' Total optimization steps = %d', num_training_steps) | |||||
| self.logger.info(' Total warmup steps = %d', num_warmup_steps) | |||||
| self.logger.info('************************************') | |||||
| def train(self, | |||||
| train_label_iter, | |||||
| train_nolabel_iter=None, | |||||
| valid_label_iter=None, | |||||
| valid_nolabel_iter=None): | |||||
| # begin training | |||||
| num_epochs = self.num_epochs - self.epoch | |||||
| for epoch in range(num_epochs): | |||||
| self.train_epoch( | |||||
| train_label_iter=train_label_iter, | |||||
| train_nolabel_iter=train_nolabel_iter, | |||||
| valid_label_iter=valid_label_iter, | |||||
| valid_nolabel_iter=valid_nolabel_iter) | |||||
| def train_epoch(self, train_label_iter, train_nolabel_iter, | |||||
| valid_label_iter, valid_nolabel_iter): | |||||
| """ | |||||
| Train an epoch. | |||||
| """ | |||||
| raise NotImplementedError | |||||
| def evaluate(self, data_label_iter, data_nolabel_iter, need_save=True): | |||||
| raise NotImplementedError | |||||
| def infer(self, data_iter, num_batches=None): | |||||
| raise NotImplementedError | |||||
| def save(self, is_best=False): | |||||
| """ save """ | |||||
| train_state = { | |||||
| 'epoch': self.epoch, | |||||
| 'batch_num': self.batch_num, | |||||
| 'best_valid_metric': self.best_valid_metric, | |||||
| 'optimizer': self.optimizer.state_dict() | |||||
| } | |||||
| if self.lr_scheduler is not None: | |||||
| train_state['lr_scheduler'] = self.lr_scheduler.state_dict() | |||||
| # Save checkpoint | |||||
| if self.save_checkpoint: | |||||
| model_file = os.path.join(self.save_dir, | |||||
| f'state_epoch_{self.epoch}.model') | |||||
| torch.save(self.model.state_dict(), model_file) | |||||
| self.logger.info(f"Saved model state to '{model_file}'") | |||||
| train_file = os.path.join(self.save_dir, | |||||
| f'state_epoch_{self.epoch}.train') | |||||
| torch.save(train_state, train_file) | |||||
| self.logger.info(f"Saved train state to '{train_file}'") | |||||
| # Save current best model | |||||
| if is_best: | |||||
| best_model_file = os.path.join(self.save_dir, 'best.model') | |||||
| torch.save(self.model.state_dict(), best_model_file) | |||||
| best_train_file = os.path.join(self.save_dir, 'best.train') | |||||
| torch.save(train_state, best_train_file) | |||||
| self.logger.info( | |||||
| f"Saved best model state to '{best_model_file}' with new best valid metric " | |||||
| f'{self.valid_metric_name.upper()}={self.best_valid_metric:.3f}' | |||||
| ) | |||||
| def load(self): | |||||
| """ load """ | |||||
| def _load_model_state(): | |||||
| model_state_dict = torch.load( | |||||
| f'{self.func_model.init_checkpoint}.model', | |||||
| map_location=lambda storage, loc: storage) | |||||
| if 'module.' in list(model_state_dict.keys())[0]: | |||||
| new_model_state_dict = OrderedDict() | |||||
| for k, v in model_state_dict.items(): | |||||
| assert k[:7] == 'module.' | |||||
| new_model_state_dict[k[7:]] = v | |||||
| model_state_dict = new_model_state_dict | |||||
| new_model_state_dict = OrderedDict() | |||||
| parameters = { | |||||
| name: param | |||||
| for name, param in self.func_model.named_parameters() | |||||
| } | |||||
| for name, param in model_state_dict.items(): | |||||
| if name in parameters: | |||||
| if param.shape != parameters[name].shape: | |||||
| assert hasattr(param, 'numpy') | |||||
| arr = param.numpy() | |||||
| z = np.random.normal( | |||||
| scale=self.func_model.initializer_range, | |||||
| size=parameters[name].shape).astype('float32') | |||||
| if name == 'embedder.token_embedding.weight': | |||||
| z[-param.shape[0]:] = arr | |||||
| print(f'part of parameter({name}) is randomly normal-initialized') | |||||
| else: | |||||
| if z.shape[0] < param.shape[0]: | |||||
| z = arr[:z.shape[0]] | |||||
| print(f'part of parameter({name}) is dropped') | |||||
| else: | |||||
| z[:param.shape[0]] = arr | |||||
| print(f'part of parameter({name}) is randomly normal-initialized') | |||||
| dtype, device = param.dtype, param.device | |||||
| z = torch.tensor(z, dtype=dtype, device=device) | |||||
| new_model_state_dict[name] = z | |||||
| else: | |||||
| new_model_state_dict[name] = param | |||||
| else: | |||||
| print(f'parameter({name}) is dropped') | |||||
| model_state_dict = new_model_state_dict | |||||
| for name in parameters: | |||||
| if name not in model_state_dict: | |||||
| if parameters[name].requires_grad: | |||||
| print(f'parameter({name}) is randomly normal-initialized') | |||||
| z = np.random.normal( | |||||
| scale=self.func_model.initializer_range, | |||||
| size=parameters[name].shape).astype('float32') | |||||
| dtype, device = parameters[name].dtype, parameters[ | |||||
| name].device | |||||
| model_state_dict[name] = torch.tensor( | |||||
| z, dtype=dtype, device=device) | |||||
| else: | |||||
| model_state_dict[name] = parameters[name] | |||||
| self.func_model.load_state_dict(model_state_dict) | |||||
| self.logger.info( | |||||
| f"Loaded model state from '{self.func_model.init_checkpoint}.model'" | |||||
| ) | |||||
| def _load_train_state(): | |||||
| train_file = f'{self.func_model.init_checkpoint}.train' | |||||
| if os.path.exists(train_file): | |||||
| train_state_dict = torch.load( | |||||
| train_file, map_location=lambda storage, loc: storage) | |||||
| self.epoch = train_state_dict['epoch'] | |||||
| self.best_valid_metric = train_state_dict['best_valid_metric'] | |||||
| if self.optimizer is not None and 'optimizer' in train_state_dict: | |||||
| self.optimizer.load_state_dict( | |||||
| train_state_dict['optimizer']) | |||||
| if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict: | |||||
| self.lr_scheduler.load_state_dict( | |||||
| train_state_dict['lr_scheduler']) | |||||
| self.logger.info( | |||||
| f"Loaded train state from '{train_file}' with (epoch-{self.epoch} " | |||||
| f'best_valid_metric={self.best_valid_metric:.3f})') | |||||
| else: | |||||
| self.logger.info('Loaded no train state') | |||||
| if self.func_model.init_checkpoint is None: | |||||
| self.logger.info('Loaded no model !!!') | |||||
| return | |||||
| _load_model_state() | |||||
| _load_train_state() | |||||
| class IntentTrainer(Trainer): | |||||
| def __init__(self, model, to_tensor, config, reader=None): | |||||
| super(IntentTrainer, self).__init__(model, to_tensor, config, reader) | |||||
| self.example = config.Model.example | |||||
| self.can_norm = config.Trainer.can_norm | |||||
| def can_normalization(self, y_pred, y_true, ex_data_iter): | |||||
| # compute ACC | |||||
| acc_original = np.mean(y_pred.argmax(1) == y_true) | |||||
| message = 'original acc: %s' % acc_original | |||||
| # compute uncertainty | |||||
| k = 3 | |||||
| y_pred_topk = np.sort(y_pred, axis=1)[:, -k:] | |||||
| y_pred_topk /= y_pred_topk.sum(axis=1, keepdims=True) | |||||
| y_pred_uncertainty =\ | |||||
| -(y_pred_topk * np.log(y_pred_topk)).sum(1) / np.log(k) | |||||
| # choose threshold | |||||
| # print(np.sort(y_pred_uncertainty)[-100:].tolist()) | |||||
| threshold = 0.7 | |||||
| y_pred_confident = y_pred[y_pred_uncertainty < threshold] | |||||
| y_pred_unconfident = y_pred[y_pred_uncertainty >= threshold] | |||||
| y_true_confident = y_true[y_pred_uncertainty < threshold] | |||||
| y_true_unconfident = y_true[y_pred_uncertainty >= threshold] | |||||
| # compute ACC again for high and low confidence sets | |||||
| acc_confident = (y_pred_confident.argmax(1) == y_true_confident).mean() \ | |||||
| if len(y_true_confident) else 0. | |||||
| acc_unconfident = (y_pred_unconfident.argmax(1) == y_true_unconfident).mean() \ | |||||
| if len(y_true_unconfident) else 0. | |||||
| message += ' (%s) confident acc: %s' % (len(y_true_confident), | |||||
| acc_confident) | |||||
| message += ' (%s) unconfident acc: %s' % (len(y_true_unconfident), | |||||
| acc_unconfident) | |||||
| # get prior distribution from training set | |||||
| prior = np.zeros(self.func_model.num_intent) | |||||
| for _, (batch, batch_size) in ex_data_iter: | |||||
| for intent_label in batch['intent_label']: | |||||
| prior[intent_label] += 1. | |||||
| prior /= prior.sum() | |||||
| # revise each sample from the low confidence set, and compute new ACC | |||||
| right, alpha, iters = 0, 1, 1 | |||||
| for i, y in enumerate(y_pred_unconfident): | |||||
| Y = np.concatenate([y_pred_confident, y[None]], axis=0) | |||||
| for j in range(iters): | |||||
| Y = Y**alpha | |||||
| Y /= Y.mean(axis=0, keepdims=True) | |||||
| Y *= prior[None] | |||||
| Y /= Y.sum(axis=1, keepdims=True) | |||||
| y = Y[-1] | |||||
| if y.argmax() == y_true_unconfident[i]: | |||||
| right += 1 | |||||
| # get final ACC | |||||
| acc_final = \ | |||||
| (acc_confident * len(y_pred_confident) + right) / \ | |||||
| len(y_pred) | |||||
| if len(y_pred_unconfident): | |||||
| message += ' new unconfident acc: %s' % ( | |||||
| right / len(y_pred_unconfident)) | |||||
| else: | |||||
| message += ' no unconfident predictions' | |||||
| message += ' final acc: %s' % acc_final | |||||
| return acc_original, acc_final, message | |||||
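The revision loop is one step of prior-guided renormalization: each low-confidence row is appended to the confident block, column-normalized, reweighted by the label prior estimated from the training set, then row-normalized back into a distribution. A toy numeric sketch of a single iteration with alpha=1:

    import numpy as np

    # Toy setup: two confident rows plus one unconfident row y, two classes.
    y_pred_confident = np.array([[0.9, 0.1], [0.8, 0.2]])
    y = np.array([0.55, 0.45])    # unconfident prediction
    prior = np.array([0.3, 0.7])  # label prior from the training data

    Y = np.concatenate([y_pred_confident, y[None]], axis=0)
    Y /= Y.mean(axis=0, keepdims=True)  # column-normalize
    Y *= prior[None]                    # reweight by the class prior
    Y /= Y.sum(axis=1, keepdims=True)   # back to row distributions
    print(Y[-1])                        # ~[0.15, 0.85]: the strong prior flips the call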
| def train_epoch(self, train_label_iter, train_nolabel_iter, | |||||
| valid_label_iter, valid_nolabel_iter): | |||||
| """ | |||||
| Train an epoch. | |||||
| """ | |||||
| times = [] | |||||
| self.epoch += 1 | |||||
| self.batch_metrics_tracker_label.clear() | |||||
| self.token_metrics_tracker_label.clear() | |||||
| self.batch_metrics_tracker_nolabel.clear() | |||||
| self.token_metrics_tracker_nolabel.clear() | |||||
| num_label_batches = len(train_label_iter) | |||||
| num_nolabel_batches = len( | |||||
| train_nolabel_iter) if train_nolabel_iter is not None else 0 | |||||
| num_batches = max(num_label_batches, num_nolabel_batches) | |||||
| train_label_iter_loop = iter(train_label_iter) | |||||
| train_nolabel_iter_loop = iter( | |||||
| train_nolabel_iter) if train_nolabel_iter is not None else None | |||||
| report_for_unlabeled_data = train_nolabel_iter is not None | |||||
| for batch_id in range(1, num_batches + 1): | |||||
| # Do a training iteration | |||||
| start_time = time.time() | |||||
| batch_list, batch_size_list, with_label_list, loss_list, metrics_list = [], [], [], [], [] | |||||
| data_file_list = [] | |||||
| # collect batch for labeled data | |||||
| try: | |||||
| data_file_label, ( | |||||
| batch_label, | |||||
| batch_size_label) = next(train_label_iter_loop) | |||||
| except StopIteration: | |||||
| train_label_iter_loop = iter(train_label_iter) | |||||
| data_file_label, ( | |||||
| batch_label, | |||||
| batch_size_label) = next(train_label_iter_loop) | |||||
| batch_list.append(batch_label) | |||||
| batch_size_list.append(batch_size_label) | |||||
| with_label_list.append(True) | |||||
| data_file_list.append(data_file_label) | |||||
| # collect batch for unlabeled data | |||||
| if train_nolabel_iter is not None: | |||||
| try: | |||||
| data_file_nolabel, ( | |||||
| batch_nolabel, | |||||
| batch_size_nolabel) = next(train_nolabel_iter_loop) | |||||
| except StopIteration: | |||||
| train_nolabel_iter_loop = iter(train_nolabel_iter) | |||||
| data_file_nolabel, ( | |||||
| batch_nolabel, | |||||
| batch_size_nolabel) = next(train_nolabel_iter_loop) | |||||
| batch_list.append(batch_nolabel) | |||||
| batch_size_list.append(batch_size_nolabel) | |||||
| with_label_list.append(False) | |||||
| data_file_list.append(data_file_nolabel) | |||||
| # forward labeled batch and unlabeled batch and collect outputs, respectively | |||||
| for (batch, batch_size, with_label, data_file) in \ | |||||
| zip(batch_list, batch_size_list, with_label_list, data_file_list): | |||||
| batch = type(batch)( | |||||
| map(lambda kv: (kv[0], self.to_tensor(kv[1])), | |||||
| batch.items())) | |||||
| if self.example and with_label: | |||||
| current_dataset = train_label_iter.data_file_to_dataset[ | |||||
| data_file] | |||||
| example_batch = self.reader.retrieve_examples( | |||||
| dataset=current_dataset, | |||||
| labels=batch['intent_label'], | |||||
| inds=batch['ids'], | |||||
| task='intent') | |||||
| example_batch = type(example_batch)( | |||||
| map(lambda kv: (kv[0], self.to_tensor(kv[1])), | |||||
| example_batch.items())) | |||||
| for k, v in example_batch.items(): | |||||
| batch[k] = v | |||||
| batch['epoch'] = self.epoch | |||||
| batch['num_steps'] = self.batch_num | |||||
| metrics = self.model( | |||||
| batch, | |||||
| is_training=True, | |||||
| with_label=with_label, | |||||
| data_file=data_file) | |||||
| loss, metrics = self.balance_metrics( | |||||
| metrics=metrics, batch_size=batch_size) | |||||
| loss_list.append(loss) | |||||
| metrics_list.append(metrics) | |||||
| # combine loss for labeled data and unlabeled data | |||||
| # TODO change the computation of combined loss of labeled batch and unlabeled batch | |||||
| loss = loss_list[0] if len( | |||||
| loss_list) == 1 else loss_list[0] + loss_list[1] | |||||
| # optimization procedure | |||||
| self.func_model._optimize( | |||||
| loss, optimizer=self.optimizer, lr_scheduler=self.lr_scheduler) | |||||
| elapsed = time.time() - start_time | |||||
| times.append(elapsed) | |||||
| self.batch_num += 1 | |||||
| # track metrics and log temporary message | |||||
| for (batch_size, metrics, | |||||
| with_label) in zip(batch_size_list, metrics_list, | |||||
| with_label_list): | |||||
| self.track_and_log_message( | |||||
| metrics=metrics, | |||||
| batch_id=batch_id, | |||||
| batch_size=batch_size, | |||||
| num_batches=num_batches, | |||||
| times=times, | |||||
| with_label=with_label) | |||||
| # evaluate | |||||
| if self.valid_steps > 0 and valid_label_iter is not None and valid_nolabel_iter is not None \ | |||||
| and batch_id % self.valid_steps == 0: | |||||
| self.evaluate( | |||||
| data_label_iter=valid_label_iter, | |||||
| data_nolabel_iter=valid_nolabel_iter) | |||||
| # compute accuracy for valid dataset | |||||
| accuracy = self.infer( | |||||
| data_iter=valid_label_iter, ex_data_iter=train_label_iter) | |||||
| # report summary message and save checkpoints | |||||
| self.save_and_log_message( | |||||
| report_for_unlabeled_data, cur_valid_metric=-accuracy) | |||||
| def forward(self, batch): | |||||
| pred = [] | |||||
| with torch.no_grad(): | |||||
| batch = type(batch)( | |||||
| map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items())) | |||||
| result = self.model.infer(inputs=batch) | |||||
| result = { | |||||
| name: result[name].cpu().detach().numpy() | |||||
| for name in result | |||||
| } | |||||
| intent_probs = result['intent_probs'] | |||||
| if self.can_norm: | |||||
| pred += [intent_probs] | |||||
| else: | |||||
| pred += np.argmax(intent_probs, axis=1).tolist() | |||||
| return pred | |||||
| def infer(self, data_iter, num_batches=None, ex_data_iter=None): | |||||
| """ | |||||
| Inference interface. | |||||
| """ | |||||
| self.logger.info('Generation starts ...') | |||||
| infer_save_file = os.path.join(self.save_dir, | |||||
| f'infer_{self.epoch}.result.json') | |||||
| # Inference | |||||
| batch_cnt = 0 | |||||
| pred, true = [], [] | |||||
| outputs, labels = [], [] | |||||
| begin_time = time.time() | |||||
| with torch.no_grad(): | |||||
| if self.example: | |||||
| for _, (batch, batch_size) in tqdm( | |||||
| ex_data_iter, desc='Building train memory.'): | |||||
| batch = type(batch)( | |||||
| map(lambda kv: (kv[0], self.to_tensor(kv[1])), | |||||
| batch.items())) | |||||
| result = self.model.infer(inputs=batch) | |||||
| result = { | |||||
| name: result[name].cpu().detach().numpy() | |||||
| for name in result | |||||
| } | |||||
| outputs.append(torch.from_numpy(result['features'])) | |||||
| labels += batch['intent_label'].tolist() | |||||
| mem = torch.cat(outputs, dim=0) | |||||
| mem = mem.cuda() if self.func_model.use_gpu else mem | |||||
| labels = torch.LongTensor(labels).unsqueeze(0) | |||||
| labels = labels.cuda() if self.func_model.use_gpu else labels | |||||
| self.logger.info(f'Memory size: {mem.size()}') | |||||
| for _, (batch, batch_size) in tqdm(data_iter, total=num_batches): | |||||
| batch = type(batch)( | |||||
| map(lambda kv: (kv[0], self.to_tensor(kv[1])), | |||||
| batch.items())) | |||||
| result = self.model.infer(inputs=batch) | |||||
| result = { | |||||
| name: result[name].cpu().detach().numpy() | |||||
| for name in result | |||||
| } | |||||
| if self.example: | |||||
| features = torch.from_numpy(result['features']) | |||||
| features = features.cuda( | |||||
| ) if self.func_model.use_gpu else features | |||||
| probs = torch.softmax(features.mm(mem.t()), dim=-1) | |||||
| intent_probs = torch.zeros( | |||||
| probs.size(0), self.func_model.num_intent) | |||||
| intent_probs = intent_probs.cuda( | |||||
| ) if self.func_model.use_gpu else intent_probs | |||||
| intent_probs = intent_probs.scatter_add( | |||||
| -1, labels.repeat(probs.size(0), 1), probs) | |||||
| intent_probs = intent_probs.cpu().detach().numpy() | |||||
| else: | |||||
| intent_probs = result['intent_probs'] | |||||
| if self.can_norm: | |||||
| pred += [intent_probs] | |||||
| true += batch['intent_label'].cpu().detach().tolist() | |||||
| else: | |||||
| pred += np.argmax(intent_probs, axis=1).tolist() | |||||
| true += batch['intent_label'].cpu().detach().tolist() | |||||
| batch_cnt += 1 | |||||
| if batch_cnt == num_batches: | |||||
| break | |||||
| if self.can_norm: | |||||
| true = np.array(true) | |||||
| pred = np.concatenate(pred, axis=0) | |||||
| acc_original, acc_final, message = self.can_normalization( | |||||
| y_pred=pred, y_true=true, ex_data_iter=ex_data_iter) | |||||
| accuracy = max(acc_original, acc_final) | |||||
| infer_results = { | |||||
| 'accuracy': accuracy, | |||||
| 'pred_labels': pred.tolist(), | |||||
| 'message': message | |||||
| } | |||||
| metrics_message = f'Accuracy: {accuracy} {message}' | |||||
| else: | |||||
| accuracy = sum(p == t for p, t in zip(pred, true)) / len(pred) | |||||
| infer_results = {'accuracy': accuracy, 'pred_labels': pred} | |||||
| metrics_message = f'Accuracy: {accuracy}' | |||||
| with open(infer_save_file, 'w') as fp: | |||||
| json.dump(infer_results, fp, indent=2) | |||||
| self.logger.info(f'Saved inference results to {infer_save_file}') | |||||
| message_prefix = f'[Infer][{self.epoch}]' | |||||
| time_cost = f'TIME-{time.time() - begin_time:.3f}' | |||||
| message = ' '.join([message_prefix, metrics_message, time_cost]) | |||||
| self.logger.info(message) | |||||
| return accuracy | |||||
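The example branch is a retrieval-style classifier: test features are matched against a memory of training features, and the softmax similarities are summed per intent id with scatter_add. A toy sketch of that aggregation step:

    import torch

    # Toy memory of three training examples over two intents.
    mem = torch.tensor([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]])  # train features
    labels = torch.tensor([[0, 0, 1]])                        # their intent ids
    features = torch.tensor([[1.0, 0.1]])                     # one test feature

    probs = torch.softmax(features.mm(mem.t()), dim=-1)       # similarity over memory
    intent_probs = torch.zeros(probs.size(0), 2).scatter_add(
        -1, labels.repeat(probs.size(0), 1), probs)           # sum mass per intent
    assert intent_probs.argmax(-1).item() == 0                # nearest neighbours vote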
| def track_and_log_message(self, metrics, batch_id, batch_size, num_batches, | |||||
| times, with_label): | |||||
| # track metrics | |||||
| batch_metrics_tracker = self.batch_metrics_tracker_label if with_label else self.batch_metrics_tracker_nolabel | |||||
| token_metrics_tracker = self.token_metrics_tracker_label if with_label else self.token_metrics_tracker_nolabel | |||||
| metrics = { | |||||
| k: v.cpu().detach().numpy() if isinstance(v, torch.Tensor) else v | |||||
| for k, v in metrics.items() | |||||
| } | |||||
| mlm_num = metrics.pop('mlm_num', 0) | |||||
| batch_metrics = {k: v for k, v in metrics.items() if 'token' not in k} | |||||
| token_metrics = {k: v for k, v in metrics.items() if 'token' in k} | |||||
| batch_metrics_tracker.update(batch_metrics, batch_size) | |||||
| token_metrics_tracker.update(token_metrics, mlm_num) | |||||
| # log message | |||||
| if self.log_steps > 0 and batch_id % self.log_steps == 0: | |||||
| batch_metrics_message = batch_metrics_tracker.value() | |||||
| token_metrics_message = token_metrics_tracker.value() | |||||
| label_prefix = 'Labeled' if with_label else 'Unlabeled' | |||||
| message_prefix = f'[Train][{self.epoch}][{batch_id}/{num_batches}][{label_prefix}]' | |||||
| avg_time = f'AVG_Time-{sum(times[-self.log_steps:]) / self.log_steps:.3f}' | |||||
| message = ' '.join([ | |||||
| message_prefix, batch_metrics_message, token_metrics_message, | |||||
| avg_time | |||||
| ]) | |||||
| self.logger.info(message) | |||||
| def save_and_log_message(self, | |||||
| report_for_unlabeled_data, | |||||
| cur_valid_metric=None): | |||||
| # report message | |||||
| batch_metrics_message = self.batch_metrics_tracker_label.summary() | |||||
| token_metrics_message = self.token_metrics_tracker_label.summary() | |||||
| message_prefix = f'[Valid][{self.epoch}][Labeled]' | |||||
| message = ' '.join( | |||||
| [message_prefix, batch_metrics_message, token_metrics_message]) | |||||
| self.logger.info(message) | |||||
| if report_for_unlabeled_data: | |||||
| batch_metrics_message = self.batch_metrics_tracker_nolabel.summary( | |||||
| ) | |||||
| token_metrics_message = self.token_metrics_tracker_nolabel.summary( | |||||
| ) | |||||
| message_prefix = f'[Valid][{self.epoch}][Unlabeled]' | |||||
| message = ' '.join( | |||||
| [message_prefix, batch_metrics_message, token_metrics_message]) | |||||
| self.logger.info(message) | |||||
| # save checkpoints | |||||
| assert cur_valid_metric is not None | |||||
| if self.is_decreased_valid_metric: | |||||
| is_best = cur_valid_metric < self.best_valid_metric | |||||
| else: | |||||
| is_best = cur_valid_metric > self.best_valid_metric | |||||
| if is_best: | |||||
| self.best_valid_metric = cur_valid_metric | |||||
| self.save(is_best) | |||||
| def balance_metrics(self, metrics, batch_size): | |||||
| if self.gpu > 1: | |||||
| for metric in metrics: | |||||
| if metric is not None: | |||||
| assert len(metric) == self.gpu | |||||
| intent_loss, mlm, token_mlm, mlm_num, kl, con = metrics | |||||
| metrics = {} | |||||
| intent_loss = torch.mean(intent_loss) | |||||
| metrics['intent_loss'] = intent_loss | |||||
| loss = intent_loss | |||||
| if mlm is not None: | |||||
| mlm_num = torch.sum(mlm_num) | |||||
| token_mlm = torch.sum(mlm) * (batch_size / self.gpu) / mlm_num | |||||
| mlm = torch.mean(mlm) | |||||
| metrics['mlm_num'] = mlm_num | |||||
| metrics['token_mlm'] = token_mlm | |||||
| metrics['mlm'] = mlm | |||||
| loss = loss + (token_mlm if self.func_model.token_loss else | |||||
| mlm) * self.func_model.mlm_ratio | |||||
| if kl is not None: | |||||
| kl = torch.mean(kl) | |||||
| metrics['kl'] = kl | |||||
| loss = loss + kl * self.func_model.kl_ratio | |||||
| if con is not None: | |||||
| con = torch.mean(con) | |||||
| metrics['con'] = con | |||||
| loss = loss + con | |||||
| metrics['loss'] = loss | |||||
| assert 'loss' in metrics | |||||
| return metrics['loss'], metrics | |||||
| def load(self): | |||||
| """ load """ | |||||
| def _load_model_state(): | |||||
| model_state_dict = torch.load( | |||||
| f'{self.func_model.init_checkpoint}', | |||||
| map_location=lambda storage, loc: storage) | |||||
| if 'module.' in list(model_state_dict.keys())[0]: | |||||
| new_model_state_dict = OrderedDict() | |||||
| for k, v in model_state_dict.items(): | |||||
| assert k[:7] == 'module.' | |||||
| new_model_state_dict[k[7:]] = v | |||||
| model_state_dict = new_model_state_dict | |||||
| new_model_state_dict = OrderedDict() | |||||
| parameters = { | |||||
| name: param | |||||
| for name, param in self.func_model.named_parameters() | |||||
| } | |||||
| for name, param in model_state_dict.items(): | |||||
| if name in parameters: | |||||
| if param.shape != parameters[name].shape: | |||||
| assert hasattr(param, 'numpy') | |||||
| arr = param.numpy() | |||||
| z = np.random.normal( | |||||
| scale=self.func_model.initializer_range, | |||||
| size=parameters[name].shape).astype('float32') | |||||
| if name == 'embedder.token_embedding.weight': | |||||
| z[-param.shape[0]:] = arr | |||||
| print(f'part of parameter({name}) is randomly normal-initialized') | |||||
| else: | |||||
| if z.shape[0] < param.shape[0]: | |||||
| z = arr[:z.shape[0]] | |||||
| print(f'part of parameter({name}) is dropped') | |||||
| else: | |||||
| z[:param.shape[0]] = arr | |||||
| print(f'part of parameter({name}) is randomly normal-initialized') | |||||
| dtype, device = param.dtype, param.device | |||||
| z = torch.tensor(z, dtype=dtype, device=device) | |||||
| new_model_state_dict[name] = z | |||||
| else: | |||||
| new_model_state_dict[name] = param | |||||
| else: | |||||
| print(f'parameter({name}) is dropped') | |||||
| model_state_dict = new_model_state_dict | |||||
| for name in parameters: | |||||
| if name not in model_state_dict: | |||||
| if parameters[name].requires_grad: | |||||
| print(f'parameter({name}) is randomly normal-initialized') | |||||
| z = np.random.normal( | |||||
| scale=self.func_model.initializer_range, | |||||
| size=parameters[name].shape).astype('float32') | |||||
| dtype, device = parameters[name].dtype, parameters[ | |||||
| name].device | |||||
| model_state_dict[name] = torch.tensor( | |||||
| z, dtype=dtype, device=device) | |||||
| else: | |||||
| model_state_dict[name] = parameters[name] | |||||
| self.func_model.load_state_dict(model_state_dict) | |||||
| self.logger.info( | |||||
| f"Loaded model state from '{self.func_model.init_checkpoint}'" | |||||
| ) | |||||
| def _load_train_state(): | |||||
| train_file = f'{self.func_model.init_checkpoint}.train' | |||||
| if os.path.exists(train_file): | |||||
| train_state_dict = torch.load( | |||||
| train_file, map_location=lambda storage, loc: storage) | |||||
| self.epoch = train_state_dict['epoch'] | |||||
| self.best_valid_metric = train_state_dict['best_valid_metric'] | |||||
| if self.optimizer is not None and 'optimizer' in train_state_dict: | |||||
| self.optimizer.load_state_dict( | |||||
| train_state_dict['optimizer']) | |||||
| if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict: | |||||
| self.lr_scheduler.load_state_dict( | |||||
| train_state_dict['lr_scheduler']) | |||||
| self.logger.info( | |||||
| f"Loaded train state from '{train_file}' with (epoch-{self.epoch} " | |||||
| f'best_valid_metric={self.best_valid_metric:.3f})') | |||||
| else: | |||||
| self.logger.info('Loaded no train state') | |||||
| if self.func_model.init_checkpoint is None: | |||||
| self.logger.info('Loaded no model !!!') | |||||
| return | |||||
| if self.do_train: | |||||
| _load_model_state() | |||||
| return | |||||
| if self.do_infer: | |||||
| _load_model_state() | |||||
| _load_train_state() | |||||
| @@ -34,6 +34,8 @@ class Tasks(object): | |||||
| # nlp tasks | # nlp tasks | ||||
| word_segmentation = 'word-segmentation' | word_segmentation = 'word-segmentation' | ||||
| nli = 'nli' | |||||
| sentiment_classification = 'sentiment-classification' | |||||
| sentiment_analysis = 'sentiment-analysis' | sentiment_analysis = 'sentiment-analysis' | ||||
| sentence_similarity = 'sentence-similarity' | sentence_similarity = 'sentence-similarity' | ||||
| text_classification = 'text-classification' | text_classification = 'text-classification' | ||||
| @@ -43,6 +45,8 @@ class Tasks(object): | |||||
| token_classification = 'token-classification' | token_classification = 'token-classification' | ||||
| conversational = 'conversational' | conversational = 'conversational' | ||||
| text_generation = 'text-generation' | text_generation = 'text-generation' | ||||
| dialog_modeling = 'dialog-modeling' | |||||
| dialog_intent_prediction = 'dialog-intent-prediction' | |||||
| table_question_answering = 'table-question-answering' | table_question_answering = 'table-question-answering' | ||||
| feature_extraction = 'feature-extraction' | feature_extraction = 'feature-extraction' | ||||
| fill_mask = 'fill-mask' | fill_mask = 'fill-mask' | ||||
| @@ -0,0 +1,66 @@ | |||||
| """ | |||||
| Parse arguments. | |||||
| """ | |||||
| import argparse | |||||
| import json | |||||
| def str2bool(v): | |||||
| if v.lower() in ('yes', 'true', 't', 'y', '1'): | |||||
| return True | |||||
| elif v.lower() in ('no', 'false', 'f', 'n', '0'): | |||||
| return False | |||||
| else: | |||||
| raise argparse.ArgumentTypeError('Unsupported value encountered.') | |||||
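| # Usage sketch: register str2bool as an argparse type so that a flag such as | |||||
| # '--use_gpu yes/no/1/0' parses to a real bool (the flag name is illustrative): | |||||
| # parser.add_argument('--use_gpu', type=str2bool, default=True) | |||||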
| class HParams(dict): | |||||
| """ Hyper-parameters class | |||||
| Store hyper-parameters in training / infer / ... scripts. | |||||
| """ | |||||
| def __getattr__(self, name): | |||||
| if name in self.keys(): | |||||
| return self[name] | |||||
| for v in self.values(): | |||||
| if isinstance(v, HParams): | |||||
| if name in v: | |||||
| return v[name] | |||||
| raise AttributeError(f"'HParams' object has no attribute '{name}'") | |||||
| def __setattr__(self, name, value): | |||||
| self[name] = value | |||||
| def save(self, filename): | |||||
| with open(filename, 'w', encoding='utf-8') as fp: | |||||
| json.dump(self, fp, ensure_ascii=False, indent=4, sort_keys=False) | |||||
| def load(self, filename): | |||||
| with open(filename, 'r', encoding='utf-8') as fp: | |||||
| params_dict = json.load(fp) | |||||
| for k, v in params_dict.items(): | |||||
| if isinstance(v, dict): | |||||
| self[k].update(HParams(v)) | |||||
| else: | |||||
| self[k] = v | |||||
| def parse_args(parser): | |||||
| """ Parse hyper-parameters from cmdline. """ | |||||
| parsed = parser.parse_args() | |||||
| args = HParams() | |||||
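| # argparse internals: _action_groups[0] holds positionals, [1] the default | |||||
| # optionals (whose first action is --help, hence the [1:] below), and any | |||||
| # later groups come from parser.add_argument_group(...) | |||||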
| optional_args = parser._action_groups[1] | |||||
| for action in optional_args._group_actions[1:]: | |||||
| arg_name = action.dest | |||||
| args[arg_name] = getattr(parsed, arg_name) | |||||
| for group in parser._action_groups[2:]: | |||||
| group_args = HParams() | |||||
| for action in group._group_actions: | |||||
| arg_name = action.dest | |||||
| group_args[arg_name] = getattr(parsed, arg_name) | |||||
| if len(group_args) > 0: | |||||
| args[group.title] = group_args | |||||
| return args | |||||
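| # Usage sketch (hypothetical flags): grouped arguments become nested HParams, | |||||
| # reachable both as args.Model.num_layers and, via HParams.__getattr__, as | |||||
| # args.num_layers. | |||||
| # parser = argparse.ArgumentParser() | |||||
| # parser.add_argument('--seed', type=int, default=11) | |||||
| # group = parser.add_argument_group('Model') | |||||
| # group.add_argument('--num_layers', type=int, default=12) | |||||
| # args = parse_args(parser) | |||||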
| @@ -0,0 +1,52 @@ | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| from torch.nn.modules.loss import _Loss | |||||
| def compute_kl_loss(p, q, filter_scores=None): | |||||
| p_loss = F.kl_div( | |||||
| F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none') | |||||
| q_loss = F.kl_div( | |||||
| F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none') | |||||
| # Choose "sum" or "mean" reduction here depending on your task | |||||
| p_loss = p_loss.sum(dim=-1) | |||||
| q_loss = q_loss.sum(dim=-1) | |||||
| # filter_scores acts as a mask implementing the filtering mechanism | |||||
| if filter_scores is not None: | |||||
| p_loss = filter_scores * p_loss | |||||
| q_loss = filter_scores * q_loss | |||||
| p_loss = p_loss.mean() | |||||
| q_loss = q_loss.mean() | |||||
| loss = (p_loss + q_loss) / 2 | |||||
| return loss | |||||
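| # Usage sketch (logits1/logits2/task_loss/alpha are assumed names, not from | |||||
| # this file): symmetric KL between two dropout-perturbed forward passes of | |||||
| # the same batch, R-Drop style. | |||||
| # logits1 = model(input_ids)  # first pass, dropout active | |||||
| # logits2 = model(input_ids)  # second pass, different dropout mask | |||||
| # loss = task_loss + alpha * compute_kl_loss(logits1, logits2) | |||||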
| class CatKLLoss(_Loss): | |||||
| """ | |||||
| CatKLLoss | |||||
| """ | |||||
| def __init__(self, reduction='mean'): | |||||
| super(CatKLLoss, self).__init__() | |||||
| assert reduction in ['none', 'sum', 'mean'] | |||||
| self.reduction = reduction | |||||
| def forward(self, log_qy, log_py): | |||||
| """ | |||||
| KL(q||p) = E_q[log q(y) - log p(y)] = sum_y q(y) * (log q(y) - log p(y)) | |||||
| log_qy: (batch_size, latent_size) | |||||
| log_py: (batch_size, latent_size) | |||||
| """ | |||||
| qy = torch.exp(log_qy) | |||||
| kl = torch.sum(qy * (log_qy - log_py), dim=1) | |||||
| if self.reduction == 'mean': | |||||
| kl = kl.mean() | |||||
| elif self.reduction == 'sum': | |||||
| kl = kl.sum() | |||||
| return kl | |||||
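| # Usage sketch (posterior_logits/prior_logits are assumed names): KL between | |||||
| # a categorical posterior q(y|x) and prior p(y), both given as log-probs. | |||||
| # log_qy = F.log_softmax(posterior_logits, dim=-1) | |||||
| # log_py = F.log_softmax(prior_logits, dim=-1) | |||||
| # kl = CatKLLoss(reduction='mean')(log_qy, log_py) | |||||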
| @@ -0,0 +1,313 @@ | |||||
| import os | |||||
| import random | |||||
| import sqlite3 | |||||
| import json | |||||
| from .ontology import all_domains, db_domains | |||||
| class MultiWozDB(object): | |||||
| def __init__(self, db_dir, db_paths): | |||||
| self.dbs = {} | |||||
| self.sql_dbs = {} | |||||
| for domain in all_domains: | |||||
| with open(os.path.join(db_dir, db_paths[domain]), 'r') as f: | |||||
| self.dbs[domain] = json.loads(f.read().lower()) | |||||
| def oneHotVector(self, domain, num): | |||||
| """Return number of available entities for particular domain.""" | |||||
| vector = [0, 0, 0, 0] | |||||
| if num == '': | |||||
| return vector | |||||
| if domain != 'train': | |||||
| if num == 0: | |||||
| vector = [1, 0, 0, 0] | |||||
| elif num == 1: | |||||
| vector = [0, 1, 0, 0] | |||||
| elif num <= 3: | |||||
| vector = [0, 0, 1, 0] | |||||
| else: | |||||
| vector = [0, 0, 0, 1] | |||||
| else: | |||||
| if num == 0: | |||||
| vector = [1, 0, 0, 0] | |||||
| elif num <= 5: | |||||
| vector = [0, 1, 0, 0] | |||||
| elif num <= 10: | |||||
| vector = [0, 0, 1, 0] | |||||
| else: | |||||
| vector = [0, 0, 0, 1] | |||||
| return vector | |||||
| def addBookingPointer(self, turn_da): | |||||
| """Add information about availability of the booking option.""" | |||||
| # Booking pointer | |||||
| # Do not consider booking two things in a single turn. | |||||
| vector = [0, 0] | |||||
| if turn_da.get('booking-nobook'): | |||||
| vector = [1, 0] | |||||
| if turn_da.get('booking-book') or turn_da.get('train-offerbooked'): | |||||
| vector = [0, 1] | |||||
| return vector | |||||
| def addDBPointer(self, domain, match_num, return_num=False): | |||||
| """Create database pointer for all related domains.""" | |||||
| # if turn_domains is None: | |||||
| # turn_domains = db_domains | |||||
| if domain in db_domains: | |||||
| vector = self.oneHotVector(domain, match_num) | |||||
| else: | |||||
| vector = [0, 0, 0, 0] | |||||
| return vector | |||||
| def addDBIndicator(self, domain, match_num, return_num=False): | |||||
| """Create database indicator for all related domains.""" | |||||
| # if turn_domains is None: | |||||
| # turn_domains = db_domains | |||||
| if domain in db_domains: | |||||
| vector = self.oneHotVector(domain, match_num) | |||||
| else: | |||||
| vector = [0, 0, 0, 0] | |||||
| # '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]' | |||||
| if vector == [0, 0, 0, 0]: | |||||
| indicator = '[db_nores]' | |||||
| else: | |||||
| indicator = '[db_%s]' % vector.index(1) | |||||
| return indicator | |||||
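| # e.g. for a non-train domain: addDBIndicator('hotel', 0) -> '[db_0]', | |||||
| # addDBIndicator('hotel', 2) -> '[db_2]' (2 falls in the 2-3 bucket), | |||||
| # and a missing match count ('') -> '[db_nores]' | |||||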
| def get_match_num(self, constraints, return_entry=False): | |||||
| """Create database pointer for all related domains.""" | |||||
| match = {'general': ''} | |||||
| entry = {} | |||||
| # if turn_domains is None: | |||||
| # turn_domains = db_domains | |||||
| for domain in all_domains: | |||||
| match[domain] = '' | |||||
| if domain in db_domains and constraints.get(domain): | |||||
| matched_ents = self.queryJsons(domain, constraints[domain]) | |||||
| match[domain] = len(matched_ents) | |||||
| if return_entry: | |||||
| entry[domain] = matched_ents | |||||
| if return_entry: | |||||
| return entry | |||||
| return match | |||||
| def pointerBack(self, vector, domain): | |||||
| # multi domain implementation | |||||
| # domnum = cfg.domain_num | |||||
| if domain.endswith(']'): | |||||
| domain = domain[1:-1] | |||||
| if domain != 'train': | |||||
| nummap = {0: '0', 1: '1', 2: '2-3', 3: '>3'} | |||||
| else: | |||||
| nummap = {0: '0', 1: '1-5', 2: '6-10', 3: '>10'} | |||||
| if vector[:4] == [0, 0, 0, 0]: | |||||
| report = '' | |||||
| else: | |||||
| num = vector.index(1) | |||||
| report = domain + ': ' + nummap[num] + '; ' | |||||
| if vector[-2] == 0 and vector[-1] == 1: | |||||
| report += 'booking: ok' | |||||
| if vector[-2] == 1 and vector[-1] == 0: | |||||
| report += 'booking: unable' | |||||
| return report | |||||
| def queryJsons(self, | |||||
| domain, | |||||
| constraints, | |||||
| exactly_match=True, | |||||
| return_name=False): | |||||
| """Returns the list of entities for a given domain | |||||
| based on the annotation of the belief state | |||||
| constraints: dict e.g. {'pricerange': 'cheap', 'area': 'west'} | |||||
| """ | |||||
| # query the db | |||||
| if domain == 'taxi': | |||||
| return [{ | |||||
| 'taxi_colors': | |||||
| random.choice(self.dbs[domain]['taxi_colors']), | |||||
| 'taxi_types': | |||||
| random.choice(self.dbs[domain]['taxi_types']), | |||||
| 'taxi_phone': [random.randint(1, 9) for _ in range(10)] | |||||
| }] | |||||
| if domain == 'police': | |||||
| return self.dbs['police'] | |||||
| if domain == 'hospital': | |||||
| if constraints.get('department'): | |||||
| for entry in self.dbs['hospital']: | |||||
| if entry.get('department') == constraints.get( | |||||
| 'department'): | |||||
| return [entry] | |||||
| else: | |||||
| return [] | |||||
| valid_cons = False | |||||
| for v in constraints.values(): | |||||
| if v not in ['not mentioned', '']: | |||||
| valid_cons = True | |||||
| if not valid_cons: | |||||
| return [] | |||||
| match_result = [] | |||||
| if 'name' in constraints: | |||||
| for db_ent in self.dbs[domain]: | |||||
| if 'name' in db_ent: | |||||
| cons = constraints['name'] | |||||
| dbn = db_ent['name'] | |||||
| if cons == dbn: | |||||
| db_ent = db_ent if not return_name else db_ent['name'] | |||||
| match_result.append(db_ent) | |||||
| return match_result | |||||
| for db_ent in self.dbs[domain]: | |||||
| match = True | |||||
| for s, v in constraints.items(): | |||||
| if s == 'name': | |||||
| continue | |||||
| if s in ['people', 'stay'] or (domain == 'hotel' and s == 'day') or \ | |||||
| (domain == 'restaurant' and s in ['day', 'time']): | |||||
| # These inform slots belong to "book info", which does not exist in the DB; | |||||
| # "book" comes from the user goal, not the DB | |||||
| continue | |||||
| skip_case = { | |||||
| "don't care": 1, | |||||
| "do n't care": 1, | |||||
| 'dont care': 1, | |||||
| 'not mentioned': 1, | |||||
| 'dontcare': 1, | |||||
| '': 1 | |||||
| } | |||||
| if skip_case.get(v): | |||||
| continue | |||||
| if s not in db_ent: | |||||
| # logging.warning('Searching warning: slot %s not in %s db'%(s, domain)) | |||||
| match = False | |||||
| break | |||||
| # v = 'guesthouse' if v == 'guest house' else v | |||||
| # v = 'swimmingpool' if v == 'swimming pool' else v | |||||
| v = 'yes' if v == 'free' else v | |||||
| if s in ['arrive', 'leave']: | |||||
| try: | |||||
| h, m = v.split( | |||||
| ':' | |||||
| ) # raises an error if the time value is not in hh:mm format | |||||
| v = int(h) * 60 + int(m) | |||||
| except Exception: | |||||
| match = False | |||||
| break | |||||
| time = int(db_ent[s].split(':')[0]) * 60 + int( | |||||
| db_ent[s].split(':')[1]) | |||||
| if s == 'arrive' and v > time: | |||||
| match = False | |||||
| if s == 'leave' and v < time: | |||||
| match = False | |||||
| else: | |||||
| if exactly_match and v != db_ent[s]: | |||||
| match = False | |||||
| break | |||||
| elif v not in db_ent[s]: | |||||
| match = False | |||||
| break | |||||
| if match: | |||||
| match_result.append(db_ent) | |||||
| if not return_name: | |||||
| return match_result | |||||
| else: | |||||
| if domain == 'train': | |||||
| match_result = [e['id'] for e in match_result] | |||||
| else: | |||||
| match_result = [e['name'] for e in match_result] | |||||
| return match_result | |||||
| def querySQL(self, domain, constraints): | |||||
| if not self.sql_dbs: | |||||
| for dom in db_domains: | |||||
| db = 'db/{}-dbase.db'.format(dom) | |||||
| conn = sqlite3.connect(db) | |||||
| c = conn.cursor() | |||||
| self.sql_dbs[dom] = c | |||||
| sql_query = 'select * from {}'.format(domain) | |||||
| flag = True | |||||
| for key, val in constraints.items(): | |||||
| if val == '' \ | |||||
| or val == 'dontcare' \ | |||||
| or val == 'not mentioned' \ | |||||
| or val == "don't care" \ | |||||
| or val == 'dont care' \ | |||||
| or val == "do n't care": | |||||
| pass | |||||
| else: | |||||
| if flag: | |||||
| sql_query += ' where ' | |||||
| val2 = val.replace("'", "''") | |||||
| # val2 = normalize(val2) | |||||
| if key == 'leaveAt': | |||||
| sql_query += r' ' + key + ' > ' + r"'" + val2 + r"'" | |||||
| elif key == 'arriveBy': | |||||
| sql_query += r' ' + key + ' < ' + r"'" + val2 + r"'" | |||||
| else: | |||||
| sql_query += r' ' + key + '=' + r"'" + val2 + r"'" | |||||
| flag = False | |||||
| else: | |||||
| val2 = val.replace("'", "''") | |||||
| # val2 = normalize(val2) | |||||
| if key == 'leaveAt': | |||||
| sql_query += r' and ' + key + ' > ' + r"'" + val2 + r"'" | |||||
| elif key == 'arriveBy': | |||||
| sql_query += r' and ' + key + ' < ' + r"'" + val2 + r"'" | |||||
| else: | |||||
| sql_query += r' and ' + key + '=' + r"'" + val2 + r"'" | |||||
| try: # "select * from attraction where name = 'queens college'" | |||||
| print(sql_query) | |||||
| return self.sql_dbs[domain].execute(sql_query).fetchall() | |||||
| except Exception: | |||||
| return [] # TODO test it | |||||
| if __name__ == '__main__': | |||||
| dbPATHs = { | |||||
| 'attraction': 'db/attraction_db_processed.json', | |||||
| 'hospital': 'db/hospital_db_processed.json', | |||||
| 'hotel': 'db/hotel_db_processed.json', | |||||
| 'police': 'db/police_db_processed.json', | |||||
| 'restaurant': 'db/restaurant_db_processed.json', | |||||
| 'taxi': 'db/taxi_db_processed.json', | |||||
| 'train': 'db/train_db_processed.json', | |||||
| } | |||||
| db = MultiWozDB('.', dbPATHs)  # __init__ takes (db_dir, db_paths); the paths above already carry the db/ prefix | |||||
| while True: | |||||
| constraints = {} | |||||
| inp = input( | |||||
| 'input belief state in format: domain-slot1=value1;slot2=value2...\n' | |||||
| ) | |||||
| domain, cons = inp.split('-') | |||||
| for sv in cons.split(';'): | |||||
| s, v = sv.split('=') | |||||
| constraints[s] = v | |||||
| # res = db.querySQL(domain, constraints) | |||||
| res = db.queryJsons(domain, constraints, return_name=True) | |||||
| print(constraints) | |||||
| print('count:', len(res), '\nnames:', res) | |||||
| @@ -0,0 +1,204 @@ | |||||
| all_domains = [ | |||||
| 'restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital' | |||||
| ] | |||||
| all_domains_with_bracket = ['[{}]'.format(item) for item in all_domains] | |||||
| db_domains = ['restaurant', 'hotel', 'attraction', 'train'] | |||||
| placeholder_tokens = [ | |||||
| '<go_r>', '<go_b>', '<go_a>', '<go_d>', '<eos_u>', '<eos_r>', '<eos_b>', | |||||
| '<eos_a>', '<eos_d>', '<eos_q>', '<sos_u>', '<sos_r>', '<sos_b>', | |||||
| '<sos_a>', '<sos_d>', '<sos_q>' | |||||
| ] | |||||
| normlize_slot_names = { | |||||
| 'car type': 'car', | |||||
| 'entrance fee': 'price', | |||||
| 'duration': 'time', | |||||
| 'leaveat': 'leave', | |||||
| 'arriveby': 'arrive', | |||||
| 'trainid': 'id' | |||||
| } | |||||
| requestable_slots = { | |||||
| 'taxi': ['car', 'phone'], | |||||
| 'police': ['postcode', 'address', 'phone'], | |||||
| 'hospital': ['address', 'phone', 'postcode'], | |||||
| 'hotel': [ | |||||
| 'address', 'postcode', 'internet', 'phone', 'parking', 'type', | |||||
| 'pricerange', 'stars', 'area', 'reference' | |||||
| ], | |||||
| 'attraction': | |||||
| ['price', 'type', 'address', 'postcode', 'phone', 'area', 'reference'], | |||||
| 'train': ['time', 'leave', 'price', 'arrive', 'id', 'reference'], | |||||
| 'restaurant': [ | |||||
| 'phone', 'postcode', 'address', 'pricerange', 'food', 'area', | |||||
| 'reference' | |||||
| ] | |||||
| } | |||||
| all_reqslot = [ | |||||
| 'car', 'address', 'postcode', 'phone', 'internet', 'parking', 'type', | |||||
| 'pricerange', 'food', 'stars', 'area', 'reference', 'time', 'leave', | |||||
| 'price', 'arrive', 'id' | |||||
| ] | |||||
| informable_slots = { | |||||
| 'taxi': ['leave', 'destination', 'departure', 'arrive'], | |||||
| 'police': [], | |||||
| 'hospital': ['department'], | |||||
| 'hotel': [ | |||||
| 'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people', | |||||
| 'area', 'stars', 'name' | |||||
| ], | |||||
| 'attraction': ['area', 'type', 'name'], | |||||
| 'train': ['destination', 'day', 'arrive', 'departure', 'people', 'leave'], | |||||
| 'restaurant': | |||||
| ['food', 'pricerange', 'area', 'name', 'time', 'day', 'people'] | |||||
| } | |||||
| all_infslot = [ | |||||
| 'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people', | |||||
| 'area', 'stars', 'name', 'leave', 'destination', 'departure', 'arrive', | |||||
| 'department', 'food', 'time' | |||||
| ] | |||||
| all_slots = all_reqslot + [ | |||||
| 'stay', 'day', 'people', 'name', 'destination', 'departure', 'department' | |||||
| ] | |||||
| get_slot = {} | |||||
| for s in all_slots: | |||||
| get_slot[s] = 1 | |||||
| # mapping slots in dialogue act to original goal slot names | |||||
| da_abbr_to_slot_name = { | |||||
| 'addr': 'address', | |||||
| 'fee': 'price', | |||||
| 'post': 'postcode', | |||||
| 'ref': 'reference', | |||||
| 'ticket': 'price', | |||||
| 'depart': 'departure', | |||||
| 'dest': 'destination', | |||||
| } | |||||
| dialog_acts = { | |||||
| 'restaurant': [ | |||||
| 'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook', | |||||
| 'offerbooked', 'nobook' | |||||
| ], | |||||
| 'hotel': [ | |||||
| 'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook', | |||||
| 'offerbooked', 'nobook' | |||||
| ], | |||||
| 'attraction': ['inform', 'request', 'nooffer', 'recommend', 'select'], | |||||
| 'train': | |||||
| ['inform', 'request', 'nooffer', 'offerbook', 'offerbooked', 'select'], | |||||
| 'taxi': ['inform', 'request'], | |||||
| 'police': ['inform', 'request'], | |||||
| 'hospital': ['inform', 'request'], | |||||
| # 'booking': ['book', 'inform', 'nobook', 'request'], | |||||
| 'general': ['bye', 'greet', 'reqmore', 'welcome'], | |||||
| } | |||||
| all_acts = [] | |||||
| for acts in dialog_acts.values(): | |||||
| for act in acts: | |||||
| if act not in all_acts: | |||||
| all_acts.append(act) | |||||
| dialog_act_params = { | |||||
| 'inform': all_slots + ['choice', 'open'], | |||||
| 'request': all_infslot + ['choice', 'price'], | |||||
| 'nooffer': all_slots + ['choice'], | |||||
| 'recommend': all_reqslot + ['choice', 'open'], | |||||
| 'select': all_slots + ['choice'], | |||||
| # 'book': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'], | |||||
| 'nobook': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'], | |||||
| 'offerbook': all_slots + ['choice'], | |||||
| 'offerbooked': all_slots + ['choice'], | |||||
| 'reqmore': [], | |||||
| 'welcome': [], | |||||
| 'bye': [], | |||||
| 'greet': [], | |||||
| } | |||||
| dialog_act_all_slots = all_slots + ['choice', 'open'] | |||||
| # special slot tokens in belief span | |||||
| # no need of this, just convert slot to [slot] e.g. pricerange -> [pricerange] | |||||
| slot_name_to_slot_token = {} | |||||
| # eos tokens definition | |||||
| eos_tokens = { | |||||
| 'user': '<eos_u>', | |||||
| 'user_delex': '<eos_u>', | |||||
| 'resp': '<eos_r>', | |||||
| 'resp_gen': '<eos_r>', | |||||
| 'pv_resp': '<eos_r>', | |||||
| 'bspn': '<eos_b>', | |||||
| 'bspn_gen': '<eos_b>', | |||||
| 'pv_bspn': '<eos_b>', | |||||
| 'bsdx': '<eos_b>', | |||||
| 'bsdx_gen': '<eos_b>', | |||||
| 'pv_bsdx': '<eos_b>', | |||||
| 'qspn': '<eos_q>', | |||||
| 'qspn_gen': '<eos_q>', | |||||
| 'pv_qspn': '<eos_q>', | |||||
| 'aspn': '<eos_a>', | |||||
| 'aspn_gen': '<eos_a>', | |||||
| 'pv_aspn': '<eos_a>', | |||||
| 'dspn': '<eos_d>', | |||||
| 'dspn_gen': '<eos_d>', | |||||
| 'pv_dspn': '<eos_d>' | |||||
| } | |||||
| # sos tokens definition | |||||
| sos_tokens = { | |||||
| 'user': '<sos_u>', | |||||
| 'user_delex': '<sos_u>', | |||||
| 'resp': '<sos_r>', | |||||
| 'resp_gen': '<sos_r>', | |||||
| 'pv_resp': '<sos_r>', | |||||
| 'bspn': '<sos_b>', | |||||
| 'bspn_gen': '<sos_b>', | |||||
| 'pv_bspn': '<sos_b>', | |||||
| 'bsdx': '<sos_b>', | |||||
| 'bsdx_gen': '<sos_b>', | |||||
| 'pv_bsdx': '<sos_b>', | |||||
| 'qspn': '<sos_q>', | |||||
| 'qspn_gen': '<sos_q>', | |||||
| 'pv_qspn': '<sos_q>', | |||||
| 'aspn': '<sos_a>', | |||||
| 'aspn_gen': '<sos_a>', | |||||
| 'pv_aspn': '<sos_a>', | |||||
| 'dspn': '<sos_d>', | |||||
| 'dspn_gen': '<sos_d>', | |||||
| 'pv_dspn': '<sos_d>' | |||||
| } | |||||
| # db tokens definition | |||||
| db_tokens = [ | |||||
| '<sos_db>', '<eos_db>', '[book_nores]', '[book_fail]', '[book_success]', | |||||
| '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]' | |||||
| ] | |||||
| # understand tokens definition | |||||
| def get_understand_tokens(prompt_num_for_understand): | |||||
| understand_tokens = [] | |||||
| for i in range(prompt_num_for_understand): | |||||
| understand_tokens.append(f'<understand_{i}>') | |||||
| return understand_tokens | |||||
| # policy tokens definition | |||||
| def get_policy_tokens(prompt_num_for_policy): | |||||
| policy_tokens = [] | |||||
| for i in range(prompt_num_for_policy): | |||||
| policy_tokens.append(f'<policy_{i}>') | |||||
| return policy_tokens | |||||
| # all special tokens definition | |||||
| def get_special_tokens(other_tokens): | |||||
| special_tokens = [ | |||||
| '<go_r>', '<go_b>', '<go_a>', '<go_d>', '<eos_u>', '<eos_r>', | |||||
| '<eos_b>', '<eos_a>', '<eos_d>', '<eos_q>', '<sos_u>', '<sos_r>', | |||||
| '<sos_b>', '<sos_a>', '<sos_d>', '<sos_q>' | |||||
| ] + db_tokens + other_tokens | |||||
| return special_tokens | |||||
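| # Usage sketch: assemble the tokenizer's full special-token list, assuming 5 | |||||
| # understand prompts and 5 policy prompts (the counts are illustrative): | |||||
| # other = get_understand_tokens(5) + get_policy_tokens(5) | |||||
| # special_tokens = get_special_tokens(other) | |||||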
| @@ -0,0 +1,6 @@ | |||||
| def hierarchical_set_score(frame1, frame2): | |||||
| # deal with empty frames | |||||
| if not (frame1 and frame2): | |||||
| return 0. | |||||
| # hierarchical set matching is not implemented yet; this stub scores 0. | |||||
| return 0. | |||||
| @@ -0,0 +1,188 @@ | |||||
| import logging | |||||
| from collections import OrderedDict | |||||
| import json | |||||
| import numpy as np | |||||
| from . import ontology | |||||
| def max_lens(X): | |||||
| lens = [len(X)] | |||||
| while isinstance(X[0], list): | |||||
| lens.append(max(map(len, X))) | |||||
| X = [x for xs in X for x in xs] | |||||
| return lens | |||||
| def list2np(X: list, padding: int = 0, dtype: str = 'int64') -> np.ndarray: | |||||
| shape = max_lens(X) | |||||
| ret = np.full(shape, padding, dtype=np.int32) | |||||
| if len(shape) == 1: | |||||
| ret = np.array(X) | |||||
| elif len(shape) == 2: | |||||
| for i, x in enumerate(X): | |||||
| ret[i, :len(x)] = np.array(x) | |||||
| elif len(shape) == 3: | |||||
| for i, xs in enumerate(X): | |||||
| for j, x in enumerate(xs): | |||||
| ret[i, j, :len(x)] = np.array(x) | |||||
| return ret.astype(dtype) | |||||
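| # e.g. list2np([[1, 2, 3], [4]]) pads ragged rows with 0: | |||||
| # array([[1, 2, 3], | |||||
| #        [4, 0, 0]]) | |||||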
| def clean_replace(s, r, t, forward=True, backward=False): | |||||
| def clean_replace_single(s, r, t, forward, backward, sidx=0): | |||||
| # idx = s[sidx:].find(r) | |||||
| idx = s.find(r) | |||||
| if idx == -1: | |||||
| return s, -1 | |||||
| idx_r = idx + len(r) | |||||
| if backward: | |||||
| while idx > 0 and s[idx - 1] != ' ':  # stop at a word boundary; the bare s[idx - 1] was always truthy | |||||
| idx -= 1 | |||||
| elif idx > 0 and s[idx - 1] != ' ': | |||||
| return s, -1 | |||||
| if forward: | |||||
| while \ | |||||
| idx_r < len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()): | |||||
| idx_r += 1 | |||||
| elif idx_r != len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()): | |||||
| return s, -1 | |||||
| return s[:idx] + t + s[idx_r:], idx_r | |||||
| sidx = 0 | |||||
| while sidx != -1: | |||||
| s, sidx = clean_replace_single(s, r, t, forward, backward, sidx) | |||||
| return s | |||||
| def py2np(lst): | |||||
| return np.array(lst) | |||||
| def write_dict(fn, dic): | |||||
| with open(fn, 'w') as f: | |||||
| json.dump(dic, f, indent=2) | |||||
| def f1_score(label_list, pred_list): | |||||
| tp = len([t for t in pred_list if t in label_list]) | |||||
| fp = max(0, len(pred_list) - tp) | |||||
| fn = max(0, len(label_list) - tp) | |||||
| precision = tp / (tp + fp + 1e-10) | |||||
| recall = tp / (tp + fn + 1e-10) | |||||
| f1 = 2 * precision * recall / (precision + recall + 1e-10) | |||||
| return f1 | |||||
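| # Worked example: label_list=['a', 'b', 'c'], pred_list=['a', 'd'] | |||||
| # -> tp=1, fp=1, fn=2, precision=0.5, recall=1/3, f1=0.4 | |||||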
| class MultiWOZVocab(object): | |||||
| def __init__(self, vocab_size=0): | |||||
| """ | |||||
| vocab for multiwoz dataset | |||||
| """ | |||||
| self.vocab_size = vocab_size | |||||
| self.vocab_size_oov = 0 # get after construction | |||||
| self._idx2word = {} # word + oov | |||||
| self._word2idx = {} # word | |||||
| self._freq_dict = {} # word + oov | |||||
| for w in [ | |||||
| '[PAD]', '<go_r>', '[UNK]', '<go_b>', '<go_a>', '<eos_u>', | |||||
| '<eos_r>', '<eos_b>', '<eos_a>', '<go_d>', '<eos_d>' | |||||
| ]: | |||||
| self._absolute_add_word(w) | |||||
| def _absolute_add_word(self, w): | |||||
| idx = len(self._idx2word) | |||||
| self._idx2word[idx] = w | |||||
| self._word2idx[w] = idx | |||||
| def add_word(self, word): | |||||
| if word not in self._freq_dict: | |||||
| self._freq_dict[word] = 0 | |||||
| self._freq_dict[word] += 1 | |||||
| def has_word(self, word): | |||||
| return self._freq_dict.get(word) | |||||
| def _add_to_vocab(self, word): | |||||
| if word not in self._word2idx: | |||||
| idx = len(self._idx2word) | |||||
| self._idx2word[idx] = word | |||||
| self._word2idx[word] = idx | |||||
| def construct(self): | |||||
| freq_dict_sorted = sorted( | |||||
| self._freq_dict.keys(), key=lambda x: -self._freq_dict[x]) | |||||
| print('Vocabulary size including oov: %d' % | |||||
| (len(freq_dict_sorted) + len(self._idx2word))) | |||||
| if len(freq_dict_sorted) + len(self._idx2word) < self.vocab_size: | |||||
| logging.warning( | |||||
| 'actual label set smaller than that configured: {}/{}'.format( | |||||
| len(freq_dict_sorted) + len(self._idx2word), | |||||
| self.vocab_size)) | |||||
| for word in ontology.all_domains + ['general']: | |||||
| word = '[' + word + ']' | |||||
| self._add_to_vocab(word) | |||||
| for word in ontology.all_acts: | |||||
| word = '[' + word + ']' | |||||
| self._add_to_vocab(word) | |||||
| for word in ontology.all_slots: | |||||
| self._add_to_vocab(word) | |||||
| for word in freq_dict_sorted: | |||||
| if word.startswith('[value_') and word.endswith(']'): | |||||
| self._add_to_vocab(word) | |||||
| for word in freq_dict_sorted: | |||||
| self._add_to_vocab(word) | |||||
| self.vocab_size_oov = len(self._idx2word) | |||||
| def load_vocab(self, vocab_path): | |||||
| self._freq_dict = json.loads( | |||||
| open(vocab_path + '.freq.json', 'r').read()) | |||||
| self._word2idx = json.loads( | |||||
| open(vocab_path + '.word2idx.json', 'r').read()) | |||||
| self._idx2word = {} | |||||
| for w, idx in self._word2idx.items(): | |||||
| self._idx2word[idx] = w | |||||
| self.vocab_size_oov = len(self._idx2word) | |||||
| print('vocab file loaded from "' + vocab_path + '"') | |||||
| print('Vocabulary size including oov: %d' % (self.vocab_size_oov)) | |||||
| def save_vocab(self, vocab_path): | |||||
| _freq_dict = OrderedDict( | |||||
| sorted( | |||||
| self._freq_dict.items(), key=lambda kv: kv[1], reverse=True)) | |||||
| write_dict(vocab_path + '.word2idx.json', self._word2idx) | |||||
| write_dict(vocab_path + '.freq.json', _freq_dict) | |||||
| def encode(self, word, include_oov=True): | |||||
| if include_oov: | |||||
| if self._word2idx.get(word, None) is None: | |||||
| raise ValueError( | |||||
| 'Unknown word: %s. Vocabulary should include oovs here.' | |||||
| % word) | |||||
| return self._word2idx[word] | |||||
| else: | |||||
| word = '[UNK]' if word not in self._word2idx else word  # '[UNK]', not '<unk>', is the OOV token registered in __init__ | |||||
| return self._word2idx[word] | |||||
| def sentence_encode(self, word_list): | |||||
| return [self.encode(_) for _ in word_list] | |||||
| def oov_idx_map(self, idx): | |||||
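| # indices beyond vocab_size are out-of-vocabulary; map them to 2, the | |||||
| # index of '[UNK]' registered in __init__ | |||||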
| return 2 if idx > self.vocab_size else idx | |||||
| def sentence_oov_map(self, index_list): | |||||
| return [self.oov_idx_map(_) for _ in index_list] | |||||
| def decode(self, idx, indicate_oov=False): | |||||
| if not self._idx2word.get(idx): | |||||
| raise ValueError( | |||||
| 'Error idx: %d. Vocabulary should include oovs here.' % idx) | |||||
| if not indicate_oov or idx < self.vocab_size: | |||||
| return self._idx2word[idx] | |||||
| else: | |||||
| return self._idx2word[idx] + '(o)' | |||||
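| # Usage sketch (toy corpus): count words, then build the index. | |||||
| # vocab = MultiWOZVocab(vocab_size=3000) | |||||
| # for tok in 'i want a cheap restaurant'.split(): | |||||
| #     vocab.add_word(tok) | |||||
| # vocab.construct() | |||||
| # idx = vocab.encode('cheap'); word = vocab.decode(idx) | |||||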
| @@ -1 +1,3 @@ | |||||
| sofa==1.0.4.2 | |||||
| https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz | |||||
| sofa==1.0.5 | |||||
| spacy>=2.3.5 | |||||
| @@ -0,0 +1,61 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.models import Model | |||||
| from modelscope.models.nlp import SpaceForDialogIntent | |||||
| from modelscope.pipelines import DialogIntentPredictionPipeline, pipeline | |||||
| from modelscope.preprocessors import DialogIntentPredictionPreprocessor | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class DialogIntentPredictionTest(unittest.TestCase): | |||||
| model_id = 'damo/nlp_space_dialog-intent-prediction' | |||||
| test_case = [ | |||||
| 'How do I locate my card?', | |||||
| 'I still have not received my new card, I ordered over a week ago.' | |||||
| ] | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run(self): | |||||
| cache_path = snapshot_download(self.model_id) | |||||
| preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path) | |||||
| model = SpaceForDialogIntent( | |||||
| model_dir=cache_path, | |||||
| text_field=preprocessor.text_field, | |||||
| config=preprocessor.config) | |||||
| pipelines = [ | |||||
| DialogIntentPredictionPipeline( | |||||
| model=model, preprocessor=preprocessor), | |||||
| pipeline( | |||||
| task=Tasks.dialog_intent_prediction, | |||||
| model=model, | |||||
| preprocessor=preprocessor) | |||||
| ] | |||||
| for my_pipeline, item in list(zip(pipelines, self.test_case)): | |||||
| print(my_pipeline(item)) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_model_from_modelhub(self): | |||||
| model = Model.from_pretrained(self.model_id) | |||||
| preprocessor = DialogIntentPredictionPreprocessor( | |||||
| model_dir=model.model_dir) | |||||
| pipelines = [ | |||||
| DialogIntentPredictionPipeline( | |||||
| model=model, preprocessor=preprocessor), | |||||
| pipeline( | |||||
| task=Tasks.dialog_intent_prediction, | |||||
| model=model, | |||||
| preprocessor=preprocessor) | |||||
| ] | |||||
| for my_pipeline, item in list(zip(pipelines, self.test_case)): | |||||
| print(my_pipeline(item)) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -0,0 +1,147 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.models import Model | |||||
| from modelscope.models.nlp import SpaceForDialogModeling | |||||
| from modelscope.pipelines import DialogModelingPipeline, pipeline | |||||
| from modelscope.preprocessors import DialogModelingPreprocessor | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class DialogModelingTest(unittest.TestCase): | |||||
| model_id = 'damo/nlp_space_dialog-modeling' | |||||
| test_case = { | |||||
| 'sng0073': { | |||||
| 'goal': { | |||||
| 'taxi': { | |||||
| 'info': { | |||||
| 'leaveat': '17:15', | |||||
| 'destination': 'pizza hut fen ditton', | |||||
| 'departure': "saint john's college" | |||||
| }, | |||||
| 'reqt': ['car', 'phone'], | |||||
| 'fail_info': {} | |||||
| } | |||||
| }, | |||||
| 'log': [{ | |||||
| 'user': | |||||
| "i would like a taxi from saint john 's college to pizza hut fen ditton .", | |||||
| 'user_delex': | |||||
| 'i would like a taxi from [value_departure] to [value_destination] .', | |||||
| 'resp': | |||||
| 'what time do you want to leave and what time do you want to arrive by ?', | |||||
| 'sys': | |||||
| 'what time do you want to leave and what time do you want to arrive by ?', | |||||
| 'pointer': '0,0,0,0,0,0', | |||||
| 'match': '', | |||||
| 'constraint': | |||||
| "[taxi] destination pizza hut fen ditton departure saint john 's college", | |||||
| 'cons_delex': '[taxi] destination departure', | |||||
| 'sys_act': '[taxi] [request] leave arrive', | |||||
| 'turn_num': 0, | |||||
| 'turn_domain': '[taxi]' | |||||
| }, { | |||||
| 'user': 'i want to leave after 17:15 .', | |||||
| 'user_delex': 'i want to leave after [value_leave] .', | |||||
| 'resp': | |||||
| 'booking completed ! your taxi will be [value_car] contact number is [value_phone]', | |||||
| 'sys': | |||||
| 'booking completed ! your taxi will be blue honda contact number is 07218068540', | |||||
| 'pointer': '0,0,0,0,0,0', | |||||
| 'match': '', | |||||
| 'constraint': | |||||
| "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", | |||||
| 'cons_delex': '[taxi] destination departure leave', | |||||
| 'sys_act': '[taxi] [inform] car phone', | |||||
| 'turn_num': 1, | |||||
| 'turn_domain': '[taxi]' | |||||
| }, { | |||||
| 'user': 'thank you for all the help ! i appreciate it .', | |||||
| 'user_delex': 'thank you for all the help ! i appreciate it .', | |||||
| 'resp': | |||||
| 'you are welcome . is there anything else i can help you with today ?', | |||||
| 'sys': | |||||
| 'you are welcome . is there anything else i can help you with today ?', | |||||
| 'pointer': '0,0,0,0,0,0', | |||||
| 'match': '', | |||||
| 'constraint': | |||||
| "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", | |||||
| 'cons_delex': '[taxi] destination departure leave', | |||||
| 'sys_act': '[general] [reqmore]', | |||||
| 'turn_num': 2, | |||||
| 'turn_domain': '[general]' | |||||
| }, { | |||||
| 'user': 'no , i am all set . have a nice day . bye .', | |||||
| 'user_delex': 'no , i am all set . have a nice day . bye .', | |||||
| 'resp': 'you too ! thank you', | |||||
| 'sys': 'you too ! thank you', | |||||
| 'pointer': '0,0,0,0,0,0', | |||||
| 'match': '', | |||||
| 'constraint': | |||||
| "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", | |||||
| 'cons_delex': '[taxi] destination departure leave', | |||||
| 'sys_act': '[general] [bye]', | |||||
| 'turn_num': 3, | |||||
| 'turn_domain': '[general]' | |||||
| }] | |||||
| } | |||||
| } | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run(self): | |||||
| cache_path = snapshot_download(self.model_id) | |||||
| preprocessor = DialogModelingPreprocessor(model_dir=cache_path) | |||||
| model = SpaceForDialogModeling( | |||||
| model_dir=cache_path, | |||||
| text_field=preprocessor.text_field, | |||||
| config=preprocessor.config) | |||||
| pipelines = [ | |||||
| DialogModelingPipeline(model=model, preprocessor=preprocessor), | |||||
| pipeline( | |||||
| task=Tasks.dialog_modeling, | |||||
| model=model, | |||||
| preprocessor=preprocessor) | |||||
| ] | |||||
| result = {} | |||||
| for step, item in enumerate(self.test_case['sng0073']['log']): | |||||
| user = item['user'] | |||||
| print('user: {}'.format(user)) | |||||
| result = pipelines[step % 2]({ | |||||
| 'user_input': user, | |||||
| 'history': result | |||||
| }) | |||||
| print('response : {}'.format(result['response'])) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_model_from_modelhub(self): | |||||
| model = Model.from_pretrained(self.model_id) | |||||
| preprocessor = DialogModelingPreprocessor(model_dir=model.model_dir) | |||||
| pipelines = [ | |||||
| DialogModelingPipeline(model=model, preprocessor=preprocessor), | |||||
| pipeline( | |||||
| task=Tasks.dialog_modeling, | |||||
| model=model, | |||||
| preprocessor=preprocessor) | |||||
| ] | |||||
| result = {} | |||||
| for step, item in enumerate(self.test_case['sng0073']['log']): | |||||
| user = item['user'] | |||||
| print('user: {}'.format(user)) | |||||
| result = pipelines[step % 2]({ | |||||
| 'user_input': user, | |||||
| 'history': result | |||||
| }) | |||||
| print('response : {}'.format(result['response'])) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -0,0 +1,52 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.models import Model | |||||
| from modelscope.models.nlp import SbertForNLI | |||||
| from modelscope.pipelines import NLIPipeline, pipeline | |||||
| from modelscope.preprocessors import NLIPreprocessor | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class NLITest(unittest.TestCase): | |||||
| model_id = 'damo/nlp_structbert_nli_chinese-base' | |||||
| sentence1 = '四川商务职业学院和四川财经职业学院哪个好?' | |||||
| sentence2 = '四川商务职业学院商务管理在哪个校区?' | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_with_direct_file_download(self): | |||||
| cache_path = snapshot_download(self.model_id) | |||||
| tokenizer = NLIPreprocessor(cache_path) | |||||
| model = SbertForNLI(cache_path, tokenizer=tokenizer) | |||||
| pipeline1 = NLIPipeline(model, preprocessor=tokenizer) | |||||
| pipeline2 = pipeline(Tasks.nli, model=model, preprocessor=tokenizer) | |||||
| print(f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n' | |||||
| f'pipeline1: {pipeline1(input=(self.sentence1, self.sentence2))}') | |||||
| print() | |||||
| print( | |||||
| f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n' | |||||
| f'pipeline2: {pipeline2(input=(self.sentence1, self.sentence2))}') | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_with_model_from_modelhub(self): | |||||
| model = Model.from_pretrained(self.model_id) | |||||
| tokenizer = NLIPreprocessor(model.model_dir) | |||||
| pipeline_ins = pipeline( | |||||
| task=Tasks.nli, model=model, preprocessor=tokenizer) | |||||
| print(pipeline_ins(input=(self.sentence1, self.sentence2))) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_model_name(self): | |||||
| pipeline_ins = pipeline(task=Tasks.nli, model=self.model_id) | |||||
| print(pipeline_ins(input=(self.sentence1, self.sentence2))) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_default_model(self): | |||||
| pipeline_ins = pipeline(task=Tasks.nli) | |||||
| print(pipeline_ins(input=(self.sentence1, self.sentence2))) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -0,0 +1,58 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.models import Model | |||||
| from modelscope.models.nlp import SbertForSentimentClassification | |||||
| from modelscope.pipelines import SentimentClassificationPipeline, pipeline | |||||
| from modelscope.preprocessors import SentimentClassificationPreprocessor | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class SentimentClassificationTest(unittest.TestCase): | |||||
| model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | |||||
| sentence1 = '启动的时候很大声音,然后就会听到1.2秒的卡察的声音,类似齿轮摩擦的声音' | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_with_direct_file_download(self): | |||||
| cache_path = snapshot_download(self.model_id) | |||||
| tokenizer = SentimentClassificationPreprocessor(cache_path) | |||||
| model = SbertForSentimentClassification( | |||||
| cache_path, tokenizer=tokenizer) | |||||
| pipeline1 = SentimentClassificationPipeline( | |||||
| model, preprocessor=tokenizer) | |||||
| pipeline2 = pipeline( | |||||
| Tasks.sentiment_classification, | |||||
| model=model, | |||||
| preprocessor=tokenizer) | |||||
| print(f'sentence1: {self.sentence1}\n' | |||||
| f'pipeline1: {pipeline1(input=self.sentence1)}') | |||||
| print() | |||||
| print(f'sentence1: {self.sentence1}\n' | |||||
| f'pipeline2: {pipeline2(input=self.sentence1)}') | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_model_from_modelhub(self): | |||||
| model = Model.from_pretrained(self.model_id) | |||||
| tokenizer = SentimentClassificationPreprocessor(model.model_dir) | |||||
| pipeline_ins = pipeline( | |||||
| task=Tasks.sentiment_classification, | |||||
| model=model, | |||||
| preprocessor=tokenizer) | |||||
| print(pipeline_ins(input=self.sentence1)) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_with_model_name(self): | |||||
| pipeline_ins = pipeline( | |||||
| task=Tasks.sentiment_classification, model=self.model_id) | |||||
| print(pipeline_ins(input=self.sentence1)) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_default_model(self): | |||||
| pipeline_ins = pipeline(task=Tasks.sentiment_classification) | |||||
| print(pipeline_ins(input=self.sentence1)) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -4,7 +4,7 @@ import unittest | |||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| from modelscope.models import Model | from modelscope.models import Model | ||||
| from modelscope.models.nlp import StructBertForTokenClassification | |||||
| from modelscope.models.nlp import SbertForTokenClassification | |||||
| from modelscope.pipelines import WordSegmentationPipeline, pipeline | from modelscope.pipelines import WordSegmentationPipeline, pipeline | ||||
| from modelscope.preprocessors import TokenClassifcationPreprocessor | from modelscope.preprocessors import TokenClassifcationPreprocessor | ||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| @@ -19,8 +19,7 @@ class WordSegmentationTest(unittest.TestCase): | |||||
| def test_run_by_direct_model_download(self): | def test_run_by_direct_model_download(self): | ||||
| cache_path = snapshot_download(self.model_id) | cache_path = snapshot_download(self.model_id) | ||||
| tokenizer = TokenClassifcationPreprocessor(cache_path) | tokenizer = TokenClassifcationPreprocessor(cache_path) | ||||
| model = StructBertForTokenClassification( | |||||
| cache_path, tokenizer=tokenizer) | |||||
| model = SbertForTokenClassification(cache_path, tokenizer=tokenizer) | |||||
| pipeline1 = WordSegmentationPipeline(model, preprocessor=tokenizer) | pipeline1 = WordSegmentationPipeline(model, preprocessor=tokenizer) | ||||
| pipeline2 = pipeline( | pipeline2 = pipeline( | ||||
| Tasks.word_segmentation, model=model, preprocessor=tokenizer) | Tasks.word_segmentation, model=model, preprocessor=tokenizer) | ||||