Added nli, sentiment_classification, dialog_intent and dialog_modeling pipelines. Also added some simple abstractions for sequence classification in NLP. Removed zero_shot_classification. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9159089master
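As a quick orientation (not part of this diff), the new pipelines could be exercised roughly as follows; the model IDs and input formats are illustrative assumptions, not verified registry entries:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # NLI: assumed to take a (premise, hypothesis) pair as input.
    nli = pipeline(Tasks.nli, model='<nli-model-id-or-local-dir>')
    print(nli(('A man is eating food.', 'A man is eating something.')))

    # Sentiment classification: a single sentence in, class probabilities out.
    senti = pipeline(Tasks.sentiment_classification,
                     model='<sentiment-model-id-or-local-dir>')
    print(senti('The screen quality exceeded my expectations.'))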
| @@ -16,6 +16,7 @@ class Models(object): | |||
| palm = 'palm-v2' | |||
| structbert = 'structbert' | |||
| veco = 'veco' | |||
| space = 'space' | |||
| # audio models | |||
| sambert_hifi_16k = 'sambert-hifi-16k' | |||
| @@ -52,7 +53,11 @@ class Pipelines(object): | |||
| word_segmentation = 'word-segmentation' | |||
| text_generation = 'text-generation' | |||
| sentiment_analysis = 'sentiment-analysis' | |||
| sentiment_classification = 'sentiment-classification' | |||
| fill_mask = 'fill-mask' | |||
| nli = 'nli' | |||
| dialog_intent_prediction = 'dialog-intent-prediction' | |||
| dialog_modeling = 'dialog-modeling' | |||
| zero_shot_classification = 'zero-shot-classification' | |||
| # audio tasks | |||
| @@ -97,6 +102,11 @@ class Preprocessors(object): | |||
| # nlp preprocessor | |||
| bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer' | |||
| palm_text_gen_tokenizer = 'palm-text-gen-tokenizer' | |||
| token_cls_tokenizer = 'token-cls-tokenizer' | |||
| nli_tokenizer = 'nli-tokenizer' | |||
| sen_cls_tokenizer = 'sen-cls-tokenizer' | |||
| dialog_intent_preprocessor = 'dialog-intent-preprocessor' | |||
| dialog_modeling_preprocessor = 'dialog-modeling-preprocessor' | |||
| sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer' | |||
| zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' | |||
| @@ -15,9 +15,13 @@ except ModuleNotFoundError as e: | |||
| try: | |||
| from .audio.kws import GenericKeyWordSpotting | |||
| from .multi_modal import OfaForImageCaptioning | |||
| from .nlp import (BertForSequenceClassification, | |||
| SbertForSentenceSimilarity, | |||
| SbertForZeroShotClassification) | |||
| from .nlp import (BertForMaskedLM, BertForSequenceClassification, | |||
| SbertForNLI, SbertForSentenceSimilarity, | |||
| SbertForSentimentClassification, | |||
| SbertForTokenClassification, | |||
| SbertForZeroShotClassification, SpaceForDialogIntent, | |||
| SpaceForDialogModeling, StructBertForMaskedLM, | |||
| VecoForMaskedLM) | |||
| from .audio.ans.frcrn import FRCRNModel | |||
| except ModuleNotFoundError as e: | |||
| if str(e) == "No module named 'pytorch'": | |||
| @@ -1,6 +1,10 @@ | |||
| from .bert_for_sequence_classification import * # noqa F403 | |||
| from .masked_language_model import * # noqa F403 | |||
| from .palm_for_text_generation import * # noqa F403 | |||
| from .sbert_for_nli import * # noqa F403 | |||
| from .sbert_for_sentence_similarity import * # noqa F403 | |||
| from .sbert_for_sentiment_classification import * # noqa F403 | |||
| from .sbert_for_token_classification import * # noqa F403 | |||
| from .sbert_for_zero_shot_classification import * # noqa F403 | |||
| from .space.dialog_intent_prediction_model import * # noqa F403 | |||
| from .space.dialog_modeling_model import * # noqa F403 | |||
| @@ -4,8 +4,8 @@ from typing import Any, Dict | |||
| import json | |||
| import numpy as np | |||
| from modelscope.metainfo import Models | |||
| from modelscope.utils.constant import Tasks | |||
| from ...metainfo import Models | |||
| from ...utils.constant import Tasks | |||
| from ..base import Model | |||
| from ..builder import MODELS | |||
| @@ -16,16 +16,22 @@ class MaskedLanguageModelBase(Model): | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| self.model = self.build_model() | |||
| def build_model(): | |||
| def build_model(self): | |||
| raise NotImplementedError() | |||
| def train(self): | |||
| return self.model.train() | |||
| def eval(self): | |||
| return self.model.eval() | |||
| @property | |||
| def config(self): | |||
| if hasattr(self.model, 'config'): | |||
| return self.model.config | |||
| return None | |||
| def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, np.ndarray]: | |||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, np.ndarray]: | |||
| """return the result by the model | |||
| Args: | |||
| @@ -35,10 +41,10 @@ class MaskedLanguageModelBase(Model): | |||
| Dict[str, np.ndarray]: results | |||
| """ | |||
| rst = self.model( | |||
| input_ids=inputs['input_ids'], | |||
| attention_mask=inputs['attention_mask'], | |||
| token_type_ids=inputs['token_type_ids']) | |||
| return {'logits': rst['logits'], 'input_ids': inputs['input_ids']} | |||
| input_ids=input['input_ids'], | |||
| attention_mask=input['attention_mask'], | |||
| token_type_ids=input['token_type_ids']) | |||
| return {'logits': rst['logits'], 'input_ids': input['input_ids']} | |||
| @MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert) | |||
| @@ -1,7 +1,7 @@ | |||
| from typing import Dict | |||
| from modelscope.metainfo import Models | |||
| from modelscope.utils.constant import Tasks | |||
| from ...metainfo import Models | |||
| from ...utils.constant import Tasks | |||
| from ..base import Model, Tensor | |||
| from ..builder import MODELS | |||
| @@ -20,13 +20,18 @@ class PalmForTextGeneration(Model): | |||
| default loader to load model weights, by default None. | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| self.model_dir = model_dir | |||
| from sofa.models.palm_v2 import PalmForConditionalGeneration, Translator | |||
| model = PalmForConditionalGeneration.from_pretrained(model_dir) | |||
| self.tokenizer = model.tokenizer | |||
| self.generator = Translator(model) | |||
| def train(self): | |||
| return self.generator.train() | |||
| def eval(self): | |||
| return self.generator.eval() | |||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| """return the result by the model | |||
| @@ -0,0 +1,23 @@ | |||
| from ...metainfo import Models | |||
| from ...utils.constant import Tasks | |||
| from ..builder import MODELS | |||
| from .sbert_for_sequence_classification import \ | |||
| SbertForSequenceClassificationBase | |||
| __all__ = ['SbertForNLI'] | |||
| @MODELS.register_module(Tasks.nli, module_name=Models.structbert) | |||
| class SbertForNLI(SbertForSequenceClassificationBase): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """initialize the text generation model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| model_cls (Optional[Any], optional): model loader, if None, use the | |||
| default loader to load model weights, by default None. | |||
| """ | |||
| super().__init__( | |||
| model_dir, *args, model_args={'num_labels': 3}, **kwargs) | |||
| assert self.model.config.num_labels == 3 | |||
| @@ -1,46 +1,15 @@ | |||
| import os | |||
| from typing import Any, Dict | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| from sofa import SbertModel | |||
| from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel | |||
| from torch import nn | |||
| from modelscope.metainfo import Models | |||
| from modelscope.utils.constant import Tasks | |||
| from ..base import Model, Tensor | |||
| from ...metainfo import Models | |||
| from ...utils.constant import Tasks | |||
| from ..builder import MODELS | |||
| from .sbert_for_sequence_classification import \ | |||
| SbertForSequenceClassificationBase | |||
| __all__ = ['SbertForSentenceSimilarity'] | |||
| class SbertTextClassifier(SbertPreTrainedModel): | |||
| def __init__(self, config): | |||
| super().__init__(config) | |||
| self.num_labels = config.num_labels | |||
| self.config = config | |||
| self.encoder = SbertModel(config, add_pooling_layer=True) | |||
| self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
| self.classifier = nn.Linear(config.hidden_size, config.num_labels) | |||
| def forward(self, input_ids=None, token_type_ids=None): | |||
| outputs = self.encoder( | |||
| input_ids, | |||
| token_type_ids=token_type_ids, | |||
| return_dict=None, | |||
| ) | |||
| pooled_output = outputs[1] | |||
| pooled_output = self.dropout(pooled_output) | |||
| logits = self.classifier(pooled_output) | |||
| return logits | |||
| @MODELS.register_module( | |||
| Tasks.sentence_similarity, module_name=Models.structbert) | |||
| class SbertForSentenceSimilarity(Model): | |||
| class SbertForSentenceSimilarity(SbertForSequenceClassificationBase): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """initialize the sentence similarity model from the `model_dir` path. | |||
| @@ -50,39 +19,7 @@ class SbertForSentenceSimilarity(Model): | |||
| model_cls (Optional[Any], optional): model loader, if None, use the | |||
| default loader to load model weights, by default None. | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| super().__init__( | |||
| model_dir, *args, model_args={'num_labels': 2}, **kwargs) | |||
| self.model_dir = model_dir | |||
| self.model = SbertTextClassifier.from_pretrained( | |||
| model_dir, num_labels=2) | |||
| self.model.eval() | |||
| self.label_path = os.path.join(self.model_dir, 'label_mapping.json') | |||
| with open(self.label_path) as f: | |||
| self.label_mapping = json.load(f) | |||
| self.id2label = {idx: name for name, idx in self.label_mapping.items()} | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: | |||
| """return the result by the model | |||
| Args: | |||
| input (Dict[str, Any]): the preprocessed data | |||
| Returns: | |||
| Dict[str, np.ndarray]: results | |||
| Example: | |||
| { | |||
| 'predictions': array([1]), # lable 0-negative 1-positive | |||
| 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), | |||
| 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||
| } | |||
| """ | |||
| input_ids = torch.tensor(input['input_ids'], dtype=torch.long) | |||
| token_type_ids = torch.tensor( | |||
| input['token_type_ids'], dtype=torch.long) | |||
| with torch.no_grad(): | |||
| logits = self.model(input_ids, token_type_ids) | |||
| probs = logits.softmax(-1).numpy() | |||
| pred = logits.argmax(-1).numpy() | |||
| logits = logits.numpy() | |||
| res = {'predictions': pred, 'probabilities': probs, 'logits': logits} | |||
| return res | |||
| assert self.model.config.num_labels == 2 | |||
| @@ -0,0 +1,22 @@ | |||
| from ...metainfo import Models | |||
| from ...utils.constant import Tasks | |||
| from ..builder import MODELS | |||
| from .sbert_for_sequence_classification import \ | |||
| SbertForSequenceClassificationBase | |||
| __all__ = ['SbertForSentimentClassification'] | |||
| @MODELS.register_module( | |||
| Tasks.sentiment_classification, module_name=Models.structbert) | |||
| class SbertForSentimentClassification(SbertForSequenceClassificationBase): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """initialize the text generation model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| """ | |||
| super().__init__( | |||
| model_dir, *args, model_args={'num_labels': 2}, **kwargs) | |||
| assert self.model.config.num_labels == 2 | |||
| @@ -0,0 +1,71 @@ | |||
| import os | |||
| from typing import Any, Dict | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| from sofa.models.sbert.modeling_sbert import SbertModel, SbertPreTrainedModel | |||
| from torch import nn | |||
| from ..base import Model | |||
| class SbertTextClassfier(SbertPreTrainedModel): | |||
| def __init__(self, config): | |||
| super().__init__(config) | |||
| self.num_labels = config.num_labels | |||
| self.config = config | |||
| self.encoder = SbertModel(config, add_pooling_layer=True) | |||
| self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
| self.classifier = nn.Linear(config.hidden_size, config.num_labels) | |||
| def forward(self, input_ids=None, token_type_ids=None): | |||
| outputs = self.encoder( | |||
| input_ids, | |||
| token_type_ids=token_type_ids, | |||
| return_dict=None, | |||
| ) | |||
| pooled_output = outputs[1] | |||
| pooled_output = self.dropout(pooled_output) | |||
| logits = self.classifier(pooled_output) | |||
| return {'logits': logits} | |||
| class SbertForSequenceClassificationBase(Model): | |||
| def __init__(self, model_dir: str, model_args=None, *args, **kwargs): | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| if model_args is None: | |||
| model_args = {} | |||
| self.model = SbertTextClassfier.from_pretrained( | |||
| model_dir, **model_args) | |||
| self.id2label = {} | |||
| self.label_path = os.path.join(self.model_dir, 'label_mapping.json') | |||
| if os.path.exists(self.label_path): | |||
| with open(self.label_path) as f: | |||
| self.label_mapping = json.load(f) | |||
| self.id2label = { | |||
| idx: name | |||
| for name, idx in self.label_mapping.items() | |||
| } | |||
| def train(self): | |||
| return self.model.train() | |||
| def eval(self): | |||
| return self.model.eval() | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: | |||
| input_ids = torch.tensor(input['input_ids'], dtype=torch.long) | |||
| token_type_ids = torch.tensor( | |||
| input['token_type_ids'], dtype=torch.long) | |||
| return self.model.forward(input_ids, token_type_ids) | |||
| def postprocess(self, input, **kwargs): | |||
| logits = input['logits'] | |||
| probs = logits.softmax(-1).numpy() | |||
| pred = logits.argmax(-1).numpy() | |||
| logits = logits.numpy() | |||
| res = {'predictions': pred, 'probabilities': probs, 'logits': logits} | |||
| return res | |||
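The SbertForSequenceClassificationBase introduced above is the "simple abstraction for sequence classification" mentioned in the description. As a rough illustration (not part of this change), a new sequence-classification task would subclass it like this; `Tasks.my_new_task` and the label count are hypothetical placeholders:

    from ...metainfo import Models
    from ...utils.constant import Tasks
    from ..builder import MODELS
    from .sbert_for_sequence_classification import \
        SbertForSequenceClassificationBase


    # 'Tasks.my_new_task' is a hypothetical task key, only for illustration.
    @MODELS.register_module(Tasks.my_new_task, module_name=Models.structbert)
    class SbertForMyNewTask(SbertForSequenceClassificationBase):

        def __init__(self, model_dir: str, *args, **kwargs):
            # model_args is forwarded to SbertTextClassfier.from_pretrained;
            # the base class also loads label_mapping.json if it exists.
            super().__init__(
                model_dir, *args, model_args={'num_labels': 4}, **kwargs)
            assert self.model.config.num_labels == 4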
| @@ -2,18 +2,17 @@ from typing import Any, Dict, Union | |||
| import numpy as np | |||
| import torch | |||
| from sofa import SbertConfig, SbertForTokenClassification | |||
| from modelscope.metainfo import Models | |||
| from modelscope.utils.constant import Tasks | |||
| from ...metainfo import Models | |||
| from ...utils.constant import Tasks | |||
| from ..base import Model, Tensor | |||
| from ..builder import MODELS | |||
| __all__ = ['StructBertForTokenClassification'] | |||
| __all__ = ['SbertForTokenClassification'] | |||
| @MODELS.register_module(Tasks.word_segmentation, module_name=Models.structbert) | |||
| class StructBertForTokenClassification(Model): | |||
| class SbertForTokenClassification(Model): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """initialize the word segmentation model from the `model_dir` path. | |||
| @@ -25,9 +24,16 @@ class StructBertForTokenClassification(Model): | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| self.model_dir = model_dir | |||
| self.model = SbertForTokenClassification.from_pretrained( | |||
| import sofa | |||
| self.model = sofa.SbertForTokenClassification.from_pretrained( | |||
| self.model_dir) | |||
| self.config = SbertConfig.from_pretrained(self.model_dir) | |||
| self.config = sofa.SbertConfig.from_pretrained(self.model_dir) | |||
| def train(self): | |||
| return self.model.train() | |||
| def eval(self): | |||
| return self.model.eval() | |||
| def forward(self, input: Dict[str, | |||
| Any]) -> Dict[str, Union[str, np.ndarray]]: | |||
| @@ -46,10 +52,12 @@ class StructBertForTokenClassification(Model): | |||
| } | |||
| """ | |||
| input_ids = torch.tensor(input['input_ids']).unsqueeze(0) | |||
| output = self.model(input_ids) | |||
| logits = output.logits | |||
| return {**self.model(input_ids), 'text': input['text']} | |||
| def postprocess(self, input: Dict[str, Tensor], | |||
| **kwargs) -> Dict[str, Tensor]: | |||
| logits = input['logits'] | |||
| pred = torch.argmax(logits[0], dim=-1) | |||
| pred = pred.numpy() | |||
| rst = {'predictions': pred, 'logits': logits, 'text': input['text']} | |||
| return rst | |||
| @@ -0,0 +1,81 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| from typing import Any, Dict | |||
| from ....metainfo import Models | |||
| from ....preprocessors.space.fields.intent_field import IntentBPETextField | |||
| from ....trainers.nlp.space.trainer.intent_trainer import IntentTrainer | |||
| from ....utils.config import Config | |||
| from ....utils.constant import ModelFile, Tasks | |||
| from ...base import Model, Tensor | |||
| from ...builder import MODELS | |||
| from .model.generator import Generator | |||
| from .model.model_base import SpaceModelBase | |||
| __all__ = ['SpaceForDialogIntent'] | |||
| @MODELS.register_module( | |||
| Tasks.dialog_intent_prediction, module_name=Models.space) | |||
| class SpaceForDialogIntent(Model): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """initialize the test generation model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| self.model_dir = model_dir | |||
| self.config = kwargs.pop( | |||
| 'config', | |||
| Config.from_file( | |||
| os.path.join(self.model_dir, ModelFile.CONFIGURATION))) | |||
| self.text_field = kwargs.pop( | |||
| 'text_field', | |||
| IntentBPETextField(self.model_dir, config=self.config)) | |||
| self.generator = Generator.create(self.config, reader=self.text_field) | |||
| self.model = SpaceModelBase.create( | |||
| model_dir=model_dir, | |||
| config=self.config, | |||
| reader=self.text_field, | |||
| generator=self.generator) | |||
| def to_tensor(array): | |||
| """ | |||
| numpy array -> tensor | |||
| """ | |||
| import torch | |||
| array = torch.tensor(array) | |||
| return array.cuda() if self.config.use_gpu else array | |||
| self.trainer = IntentTrainer( | |||
| model=self.model, | |||
| to_tensor=to_tensor, | |||
| config=self.config, | |||
| reader=self.text_field) | |||
| self.trainer.load() | |||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| """return the result by the model | |||
| Args: | |||
| input (Dict[str, Any]): the preprocessed data | |||
| Returns: | |||
| Dict[str, np.ndarray]: results | |||
| Example: | |||
| { | |||
| 'predictions': array([1]), # label 0-negative 1-positive | |||
| 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), | |||
| 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||
| } | |||
| """ | |||
| import numpy as np | |||
| pred = self.trainer.forward(input) | |||
| pred = np.squeeze(pred[0], 0) | |||
| return {'pred': pred} | |||
| @@ -0,0 +1,82 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| from typing import Any, Dict, Optional | |||
| from ....metainfo import Models | |||
| from ....preprocessors.space.fields.gen_field import MultiWOZBPETextField | |||
| from ....trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer | |||
| from ....utils.config import Config | |||
| from ....utils.constant import ModelFile, Tasks | |||
| from ...base import Model, Tensor | |||
| from ...builder import MODELS | |||
| from .model.generator import Generator | |||
| from .model.model_base import SpaceModelBase | |||
| __all__ = ['SpaceForDialogModeling'] | |||
| @MODELS.register_module(Tasks.dialog_modeling, module_name=Models.space) | |||
| class SpaceForDialogModeling(Model): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """initialize the test generation model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| self.model_dir = model_dir | |||
| self.config = kwargs.pop( | |||
| 'config', | |||
| Config.from_file( | |||
| os.path.join(self.model_dir, ModelFile.CONFIGURATION))) | |||
| self.text_field = kwargs.pop( | |||
| 'text_field', | |||
| MultiWOZBPETextField(self.model_dir, config=self.config)) | |||
| self.generator = Generator.create(self.config, reader=self.text_field) | |||
| self.model = SpaceModelBase.create( | |||
| model_dir=model_dir, | |||
| config=self.config, | |||
| reader=self.text_field, | |||
| generator=self.generator) | |||
| def to_tensor(array): | |||
| """ | |||
| numpy array -> tensor | |||
| """ | |||
| import torch | |||
| array = torch.tensor(array) | |||
| return array.cuda() if self.config.use_gpu else array | |||
| self.trainer = MultiWOZTrainer( | |||
| model=self.model, | |||
| to_tensor=to_tensor, | |||
| config=self.config, | |||
| reader=self.text_field, | |||
| evaluator=None) | |||
| self.trainer.load() | |||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| """return the result by the model | |||
| Args: | |||
| input (Dict[str, Any]): the preprocessed data | |||
| Returns: | |||
| Dict[str, np.ndarray]: results | |||
| Example: | |||
| { | |||
| 'predictions': array([1]), # label 0-negative 1-positive | |||
| 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), | |||
| 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||
| } | |||
| """ | |||
| turn = {'user': input['user']} | |||
| old_pv_turn = input['history'] | |||
| pv_turn = self.trainer.forward(turn=turn, old_pv_turn=old_pv_turn) | |||
| return pv_turn | |||
| @@ -0,0 +1,3 @@ | |||
| from .gen_unified_transformer import GenUnifiedTransformer | |||
| from .intent_unified_transformer import IntentUnifiedTransformer | |||
| from .unified_transformer import UnifiedTransformer | |||
| @@ -0,0 +1,283 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import torch | |||
| from .unified_transformer import UnifiedTransformer | |||
| class GenUnifiedTransformer(UnifiedTransformer): | |||
| """ | |||
| Implement generation unified transformer. | |||
| """ | |||
| def __init__(self, model_dir, config, reader, generator): | |||
| super(GenUnifiedTransformer, self).__init__(model_dir, config, reader, | |||
| generator) | |||
| self.understand = config.BPETextField.understand | |||
| if self.use_gpu: | |||
| self.cuda() | |||
| return | |||
| def _forward(self, inputs, is_training, with_label): | |||
| """ Real forward process of model in different mode(train/test). """ | |||
| def cat(x, y, dim=1): | |||
| return torch.cat([x, y], dim=dim) | |||
| outputs = {} | |||
| if self.understand or self.policy: | |||
| if self.understand: | |||
| prompt_token = inputs['understand_token'] | |||
| prompt_mask = inputs['understand_mask'] | |||
| if self.policy: | |||
| prompt_token = cat(prompt_token, inputs['policy_token']) | |||
| prompt_mask = cat(prompt_mask, inputs['policy_mask']) | |||
| else: | |||
| prompt_token = inputs['policy_token'] | |||
| prompt_mask = inputs['policy_mask'] | |||
| enc_embed, dec_embed, prompt_embed = self._encoder_prompt_decoder_network( | |||
| src_token=inputs['src_token'], | |||
| src_mask=inputs['src_mask'], | |||
| tgt_token=inputs['tgt_token'][:, :-1], | |||
| tgt_mask=inputs['tgt_mask'][:, :-1], | |||
| prompt_token=prompt_token, | |||
| prompt_mask=prompt_mask, | |||
| src_pos=inputs['src_pos'], | |||
| src_type=inputs['src_type'], | |||
| src_turn=inputs['src_turn'], | |||
| tgt_pos=inputs['tgt_pos'][:, :-1], | |||
| tgt_type=inputs['tgt_type'][:, :-1], | |||
| tgt_turn=inputs['tgt_turn'][:, :-1]) | |||
| else: | |||
| enc_embed, dec_embed = self._encoder_decoder_network( | |||
| src_token=inputs['src_token'], | |||
| src_mask=inputs['src_mask'], | |||
| tgt_token=inputs['tgt_token'][:, :-1], | |||
| tgt_mask=inputs['tgt_mask'][:, :-1], | |||
| src_pos=inputs['src_pos'], | |||
| src_type=inputs['src_type'], | |||
| src_turn=inputs['src_turn'], | |||
| tgt_pos=inputs['tgt_pos'][:, :-1], | |||
| tgt_type=inputs['tgt_type'][:, :-1], | |||
| tgt_turn=inputs['tgt_turn'][:, :-1]) | |||
| outputs['dec_probs'] = self._dec_head(dec_embed=dec_embed) | |||
| return outputs | |||
| def _collect_metrics(self, inputs, outputs, with_label, data_file): | |||
| metrics = {} | |||
| loss = 0. | |||
| label = inputs['tgt_token'][:, 1:] | |||
| token_num = torch.sum(torch.sum(inputs['tgt_mask'], dim=1) - 1) | |||
| nll = self.nll_loss( | |||
| torch.log(outputs['dec_probs'] + 1e-12).permute(0, 2, 1), label) | |||
| nll = torch.sum(nll, dim=1) | |||
| token_nll = torch.sum(nll) / token_num | |||
| nll = torch.mean(nll) | |||
| metrics['nll'] = nll | |||
| metrics['token_nll'] = token_nll | |||
| metrics['token_num'] = token_num | |||
| loss = loss + (token_nll if self.token_loss else nll) | |||
| metrics['loss'] = loss | |||
| if self.gpu > 1: | |||
| return nll, token_nll, token_num | |||
| else: | |||
| return metrics | |||
| def _optimize(self, loss, do_update=False, optimizer=None): | |||
| """ Optimize loss function and update model. """ | |||
| assert optimizer is not None | |||
| if self.gradient_accumulation_steps > 1: | |||
| loss = loss / self.gradient_accumulation_steps | |||
| loss.backward() | |||
| if self.grad_clip is not None and self.grad_clip > 0: | |||
| torch.nn.utils.clip_grad_norm_( | |||
| parameters=self.parameters(), max_norm=self.grad_clip) | |||
| if do_update: | |||
| optimizer.step() | |||
| optimizer.zero_grad() | |||
| return | |||
| def _init_state(self, | |||
| src_token, | |||
| src_mask, | |||
| src_pos=None, | |||
| src_type=None, | |||
| src_turn=None): | |||
| """ Initialize decode state. """ | |||
| state = {} | |||
| batch_size = src_token.shape[0] | |||
| src_embed = self.embedder(src_token, src_pos, src_type, src_turn) | |||
| src_embed = self.embed_layer_norm(src_embed) | |||
| mask = self._create_mask(src_mask, append_head=False) | |||
| enc_out = src_embed | |||
| cache = {} | |||
| for _l, layer in enumerate(self.layers): | |||
| cache[f'layer_{_l}'] = {} | |||
| enc_out = layer(enc_out, mask, cache[f'layer_{_l}']) | |||
| state['cache'] = cache | |||
| state['mask'] = mask[:, :1] | |||
| state['batch_size'] = batch_size | |||
| shape = [batch_size, 1, 1] | |||
| state['pred_mask'] = torch.ones(shape, dtype=torch.float32) | |||
| state['pred_pos'] = torch.zeros(shape, dtype=torch.int64) | |||
| state['pred_type'] = torch.zeros(shape, dtype=torch.int64) | |||
| state['pred_turn'] = torch.zeros(shape, dtype=torch.int64) | |||
| if self.use_gpu: | |||
| state['pred_mask'] = state['pred_mask'].cuda() | |||
| state['pred_pos'] = state['pred_pos'].cuda() | |||
| state['pred_type'] = state['pred_type'].cuda() | |||
| state['pred_turn'] = state['pred_turn'].cuda() | |||
| return state | |||
| def _init_prompt_state(self, | |||
| src_token, | |||
| src_mask, | |||
| prompt_token, | |||
| prompt_mask, | |||
| src_pos=None, | |||
| src_type=None, | |||
| src_turn=None, | |||
| prompt_pos=None, | |||
| prompt_type=None, | |||
| prompt_turn=None): | |||
| """ Initialize decode state. """ | |||
| state = {} | |||
| batch_size = src_token.shape[0] | |||
| src_embed = self.embedder(src_token, src_pos, src_type, src_turn) | |||
| prompt_embed = self.embedder(prompt_token, prompt_pos, prompt_type, | |||
| prompt_turn) | |||
| embed = torch.cat([src_embed, prompt_embed], dim=1) | |||
| embed = self.embed_layer_norm(embed) | |||
| enc_out = embed | |||
| enc_mask = self._create_mask(src_mask, auto_regressive=False) | |||
| dec_mask = self._create_mask(prompt_mask, auto_regressive=True) | |||
| mask = self._join_mask(enc_mask, dec_mask) | |||
| cache = {} | |||
| for _l, layer in enumerate(self.layers): | |||
| cache[f'layer_{_l}'] = {} | |||
| enc_out = layer(enc_out, mask, cache[f'layer_{_l}']) | |||
| state['cache'] = cache | |||
| state['mask'] = mask[:, -1:] # state["mask"] = mask[:, :1] | |||
| state['batch_size'] = batch_size | |||
| shape = [batch_size, 1, 1] | |||
| state['pred_mask'] = torch.ones(shape, dtype=torch.float32) | |||
| state['pred_pos'] = torch.zeros(shape, dtype=torch.int64) | |||
| state['pred_type'] = torch.zeros(shape, dtype=torch.int64) | |||
| state['pred_turn'] = torch.zeros(shape, dtype=torch.int64) | |||
| if self.use_gpu: | |||
| state['pred_mask'] = state['pred_mask'].cuda() | |||
| state['pred_pos'] = state['pred_pos'].cuda() | |||
| state['pred_type'] = state['pred_type'].cuda() | |||
| state['pred_turn'] = state['pred_turn'].cuda() | |||
| return state | |||
| def _decode(self, state): | |||
| """ Decoding one time stamp. """ | |||
| # shape: [batch_size, 1, seq_len] | |||
| mask = state['mask'] | |||
| # shape: [batch_size, 1, 1] | |||
| pred_token = state['pred_token'] | |||
| pred_mask = state['pred_mask'] | |||
| pred_pos = state['pred_pos'] | |||
| pred_type = state['pred_type'] | |||
| pred_turn = state['pred_turn'] | |||
| # list of shape(len: num_layers): [batch_size, seq_len, hidden_dim] | |||
| cache = state['cache'] | |||
| pred_embed = self.embedder(pred_token, pred_pos, pred_type, | |||
| pred_turn).squeeze(-2) | |||
| pred_embed = self.embed_layer_norm(pred_embed) | |||
| # shape: [batch_size, 1, seq_len + 1] | |||
| mask = torch.cat([mask, 1 - pred_mask], dim=2) | |||
| # shape: [batch_size, 1, hidden_dim] | |||
| for _l, layer in enumerate(self.layers): | |||
| pred_embed = layer(pred_embed, mask, cache[f'layer_{_l}']) | |||
| # shape: [batch_size, vocab_size] | |||
| pred_probs = self._dec_head(dec_embed=pred_embed[:, 0]) | |||
| pred_logits = torch.log(pred_probs) | |||
| state['mask'] = mask | |||
| return pred_logits, state | |||
| def _infer(self, | |||
| inputs, | |||
| start_id=None, | |||
| eos_id=None, | |||
| max_gen_len=None, | |||
| prev_input=None): | |||
| """ Real inference process of model. """ | |||
| def cat(x, y, dim=1): | |||
| return torch.cat([x, y], dim=dim) | |||
| # Initial decode state. | |||
| if self.understand or self.policy: | |||
| if self.understand: | |||
| prompt_token = inputs['understand_token'] | |||
| prompt_mask = inputs['understand_mask'] | |||
| if self.policy: | |||
| prompt_token = cat(prompt_token, inputs['policy_token']) | |||
| prompt_mask = cat(prompt_mask, inputs['policy_mask']) | |||
| else: | |||
| prompt_token = inputs['policy_token'] | |||
| prompt_mask = inputs['policy_mask'] | |||
| state = self._init_prompt_state( | |||
| src_token=inputs['src_token'], | |||
| src_mask=inputs['src_mask'], | |||
| prompt_token=prompt_token, | |||
| prompt_mask=prompt_mask, | |||
| src_pos=inputs['src_pos'], | |||
| src_type=inputs['src_type'], | |||
| src_turn=inputs['src_turn']) | |||
| else: | |||
| state = self._init_state( | |||
| src_token=inputs['src_token'], | |||
| src_mask=inputs['src_mask'], | |||
| src_pos=inputs['src_pos'], | |||
| src_type=inputs['src_type'], | |||
| src_turn=inputs['src_turn']) | |||
| # Generation process. | |||
| gen_results = self.generator( | |||
| step_fn=self._decode, | |||
| state=state, | |||
| start_id=start_id, | |||
| eos_id=eos_id, | |||
| max_gen_len=max_gen_len, | |||
| prev_input=prev_input) | |||
| outputs = gen_results['preds'] | |||
| return outputs | |||
| GenUnifiedTransformer.register('GenUnifiedTransformer') | |||
| @@ -0,0 +1,287 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import math | |||
| import numpy as np | |||
| import torch | |||
| def repeat(var, times): | |||
| if isinstance(var, list): | |||
| return [repeat(x, times) for x in var] | |||
| elif isinstance(var, dict): | |||
| return {k: repeat(v, times) for k, v in var.items()} | |||
| elif isinstance(var, torch.Tensor): | |||
| var = var.unsqueeze(1) | |||
| expand_times = [1] * len(var.shape) | |||
| expand_times[1] = times | |||
| dtype = var.dtype | |||
| var = var.float() | |||
| var = var.repeat(*expand_times) | |||
| shape = [var.shape[0] * var.shape[1]] + list(var.shape[2:]) | |||
| var = var.reshape(*shape) | |||
| var = torch.tensor(var, dtype=dtype) | |||
| return var | |||
| else: | |||
| return var | |||
| def gather(var, idx): | |||
| if isinstance(var, list): | |||
| return [gather(x, idx) for x in var] | |||
| elif isinstance(var, dict): | |||
| return {k: gather(v, idx) for k, v in var.items()} | |||
| elif isinstance(var, torch.Tensor): | |||
| out = var.index_select(dim=0, index=idx) | |||
| return out | |||
| else: | |||
| return var | |||
| class Generator(object): | |||
| """ Genrator class. """ | |||
| _registry = dict() | |||
| @classmethod | |||
| def register(cls, name): | |||
| Generator._registry[name] = cls | |||
| return | |||
| @staticmethod | |||
| def by_name(name): | |||
| return Generator._registry[name] | |||
| @staticmethod | |||
| def create(config, *args, **kwargs): | |||
| """ Create generator. """ | |||
| generator_cls = Generator.by_name(config.Generator.generator) | |||
| return generator_cls(config, *args, **kwargs) | |||
| def __init__(self, config, reader): | |||
| self.vocab_size = reader.vocab_size | |||
| self.bos_id = reader.bos_id | |||
| self.eos_id = reader.eos_id | |||
| self.unk_id = reader.unk_id | |||
| self.pad_id = reader.pad_id | |||
| self.min_gen_len = config.Generator.min_gen_len | |||
| self.max_gen_len = config.Generator.max_gen_len | |||
| self.use_gpu = config.use_gpu | |||
| assert 1 <= self.min_gen_len <= self.max_gen_len | |||
| return | |||
| def __call__(self, step_fn, state): | |||
| """ | |||
| Running generation. | |||
| @param : step_fn : decoding one step | |||
| @type : function | |||
| @param : state : initial state | |||
| @type : dict | |||
| """ | |||
| raise NotImplementedError | |||
| class BeamSearch(Generator): | |||
| """ BeamSearch generator. """ | |||
| def __init__(self, config, reader): | |||
| super().__init__(config, reader) | |||
| self.beam_size = config.Generator.beam_size | |||
| self.length_average = config.Generator.length_average | |||
| self.length_penalty = config.Generator.length_penalty | |||
| self.ignore_unk = config.Generator.ignore_unk | |||
| return | |||
| def __call__(self, | |||
| step_fn, | |||
| state, | |||
| start_id=None, | |||
| eos_id=None, | |||
| max_gen_len=None, | |||
| prev_input=None): | |||
| """ | |||
| Running beam search. | |||
| @param : step_fn : decoding one step | |||
| @type : function | |||
| @param : state : initial state | |||
| @type : dict | |||
| """ | |||
| if prev_input is not None: | |||
| if isinstance(prev_input, list): | |||
| length = max(list(map(lambda x: len(x), prev_input))) | |||
| prev_input_numpy = np.full((len(prev_input), length), | |||
| self.pad_id) | |||
| for i, x in enumerate(prev_input): | |||
| prev_input_numpy[i, :len(x)] = x | |||
| prev_input_tensor = torch.from_numpy(prev_input_numpy) | |||
| if self.use_gpu: | |||
| prev_input_tensor = prev_input_tensor.cuda() | |||
| for i in range(length): | |||
| state['pred_token'] = prev_input_tensor[:, i].unsqueeze( | |||
| -1).unsqueeze(-1) | |||
| if i != 0: | |||
| state['pred_mask'] = torch.not_equal( | |||
| state['pred_token'], self.pad_id).float() | |||
| state['pred_pos'] = state['pred_pos'] + state[ | |||
| 'pred_mask'].int() | |||
| _, state = step_fn(state) | |||
| else: | |||
| assert isinstance(prev_input, torch.Tensor) | |||
| for i, input in enumerate(prev_input): | |||
| state['pred_token'] = input.expand(1, 1, 1) | |||
| if i != 0: | |||
| state['pred_mask'] = torch.not_equal( | |||
| state['pred_token'], self.pad_id).float() | |||
| state['pred_pos'] = state['pred_pos'] + 1 | |||
| _, state = step_fn(state) | |||
| batch_size = state['batch_size'] | |||
| beam_size = self.beam_size | |||
| # shape: [batch_size, 1] | |||
| pos_index = torch.arange( | |||
| 0, batch_size, 1, dtype=torch.int64) * beam_size | |||
| pos_index = pos_index.unsqueeze(1) | |||
| # shape: [batch_size, beam_size, 1] | |||
| if start_id is None: | |||
| start_id = self.bos_id | |||
| if eos_id is None: | |||
| eos_id = self.eos_id | |||
| predictions = torch.ones([batch_size, beam_size, 1], | |||
| dtype=torch.int64) * start_id | |||
| if self.use_gpu: | |||
| pos_index = pos_index.cuda() | |||
| predictions = predictions.cuda() | |||
| # initial input (start_id) | |||
| state['pred_token'] = predictions[:, :1] | |||
| if prev_input is not None: | |||
| state['pred_mask'] = torch.not_equal(state['pred_token'], | |||
| self.pad_id).float() | |||
| state['pred_pos'] = state['pred_pos'] + 1 | |||
| # shape: [batch_size, vocab_size] | |||
| scores, state = step_fn(state) | |||
| unk_penalty = np.zeros(self.vocab_size, dtype='float32') | |||
| unk_penalty[self.unk_id] = -1e10 | |||
| unk_penalty = torch.from_numpy(unk_penalty) | |||
| eos_penalty = np.zeros(self.vocab_size, dtype='float32') | |||
| eos_penalty[eos_id] = -1e10 | |||
| eos_penalty = torch.from_numpy(eos_penalty) | |||
| scores_after_end = np.full(self.vocab_size, -1e10, dtype='float32') | |||
| scores_after_end[ | |||
| self. | |||
| pad_id] = 0 # we want <pad> to be generated after <eos>, so the maximum log(p(<pad>)) is 0 | |||
| scores_after_end = torch.from_numpy(scores_after_end) | |||
| if self.use_gpu: | |||
| unk_penalty = unk_penalty.cuda() | |||
| eos_penalty = eos_penalty.cuda() | |||
| scores_after_end = scores_after_end.cuda() | |||
| if self.ignore_unk: | |||
| scores = scores + unk_penalty | |||
| scores = scores + eos_penalty | |||
| # shape: [batch_size, beam_size] | |||
| sequence_scores, preds = torch.topk(scores, self.beam_size) | |||
| predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2) | |||
| state = repeat(state, beam_size) | |||
| if max_gen_len is None: | |||
| max_gen_len = self.max_gen_len | |||
| for step in range(2, max_gen_len + 1): | |||
| pre_ids = predictions[:, :, -1:] | |||
| state['pred_token'] = pre_ids.reshape(batch_size * beam_size, 1, 1) | |||
| state['pred_mask'] = torch.not_equal(state['pred_token'], | |||
| self.pad_id).float() | |||
| state['pred_pos'] = state['pred_pos'] + 1 | |||
| scores, state = step_fn(state) | |||
| # Generate next | |||
| # scores shape: [batch_size * beam_size, vocab_size] | |||
| if self.ignore_unk: | |||
| scores = scores + unk_penalty | |||
| if step <= self.min_gen_len: | |||
| scores = scores + eos_penalty | |||
| # scores shape: [batch_size, beam_size, vocab_size] | |||
| scores = scores.reshape(batch_size, beam_size, self.vocab_size) | |||
| # previous token is [PAD] or [EOS] | |||
| pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \ | |||
| (1 - torch.not_equal(pre_ids, self.pad_id).float()) | |||
| scores = scores * (1 - pre_eos_mask) + pre_eos_mask.repeat( | |||
| 1, 1, self.vocab_size) * scores_after_end | |||
| if self.length_average: | |||
| scaled_value = \ | |||
| pre_eos_mask + (1 - pre_eos_mask) * (1 - 1 / step) | |||
| sequence_scores = sequence_scores.unsqueeze(2) * scaled_value | |||
| scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 / step) | |||
| scores = scores * scaled_value | |||
| elif self.length_penalty >= 0.0: | |||
| scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \ | |||
| (math.pow((4 + step) / (5 + step), self.length_penalty)) | |||
| sequence_scores = scaled_value * sequence_scores | |||
| scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \ | |||
| (math.pow(1 / (5 + step), self.length_penalty)) | |||
| scores = scores * scaled_value | |||
| scores = scores + sequence_scores.unsqueeze(-1) | |||
| scores = scores.reshape(batch_size, beam_size * self.vocab_size) | |||
| topk_scores, topk_indices = torch.topk(scores, beam_size) | |||
| # topk_indices: [batch_size, beam_size * self.vocab_size] (already reshaped) | |||
| parent_idx = topk_indices.floor_divide(self.vocab_size) | |||
| preds = topk_indices % self.vocab_size | |||
| # Gather state / sequence_scores | |||
| parent_idx = parent_idx + pos_index | |||
| parent_idx = parent_idx.reshape(batch_size * beam_size) | |||
| state = gather(state, parent_idx) | |||
| sequence_scores = topk_scores | |||
| predictions = predictions.reshape(batch_size * beam_size, step) | |||
| predictions = gather(predictions, parent_idx) | |||
| predictions = predictions.reshape(batch_size, beam_size, step) | |||
| predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2) | |||
| # The last token should be <eos> or <pad> | |||
| pre_ids = predictions[:, :, -1] | |||
| pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \ | |||
| (1 - torch.not_equal(pre_ids, self.pad_id).float()) | |||
| sequence_scores = sequence_scores * pre_eos_mask + ( | |||
| 1 - pre_eos_mask) * (-1e10) | |||
| # first get the indices in ascending order, then sort "predictions" and "sequence_scores" | |||
| indices = torch.argsort(sequence_scores, dim=1) | |||
| indices = indices + pos_index | |||
| indices = indices.reshape(-1) | |||
| sequence_scores = sequence_scores.reshape(batch_size * beam_size) | |||
| predictions = predictions.reshape(batch_size * beam_size, -1) | |||
| sequence_scores = gather(sequence_scores, indices) | |||
| predictions = gather(predictions, indices) | |||
| sequence_scores = sequence_scores.reshape(batch_size, beam_size) | |||
| predictions = predictions.reshape(batch_size, beam_size, -1) | |||
| results = { | |||
| 'preds': predictions[:, -1], | |||
| 'scores': sequence_scores[:, -1] | |||
| } | |||
| return results | |||
| BeamSearch.register('BeamSearch') | |||
| @@ -0,0 +1,197 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from .....utils.nlp.space.criterions import compute_kl_loss | |||
| from .unified_transformer import UnifiedTransformer | |||
| class IntentUnifiedTransformer(UnifiedTransformer): | |||
| """ | |||
| Implement intent unified transformer. | |||
| """ | |||
| def __init__(self, model_dir, config, reader, generator): | |||
| super(IntentUnifiedTransformer, self).__init__(model_dir, config, | |||
| reader, generator) | |||
| self.example = config.Model.example | |||
| self.num_intent = config.Model.num_intent | |||
| self.with_rdrop = config.Model.with_rdrop | |||
| self.kl_ratio = config.Model.kl_ratio | |||
| self.loss_fct = nn.CrossEntropyLoss() | |||
| if self.example: | |||
| self.loss_fct = nn.NLLLoss() | |||
| else: | |||
| self.intent_classifier = nn.Linear(self.hidden_dim, | |||
| self.num_intent) | |||
| self.loss_fct = nn.CrossEntropyLoss() | |||
| if self.use_gpu: | |||
| self.cuda() | |||
| return | |||
| def _forward(self, inputs, is_training, with_label): | |||
| """ Real forward process of model in different mode(train/test). """ | |||
| def aug(v): | |||
| assert isinstance(v, torch.Tensor) | |||
| return torch.cat([v, v], dim=0) | |||
| outputs = {} | |||
| if self.with_mlm: | |||
| mlm_embed = self._encoder_network( | |||
| input_token=inputs['mlm_token'], | |||
| input_mask=inputs['src_mask'], | |||
| input_pos=inputs['src_pos'], | |||
| input_type=inputs['src_type'], | |||
| input_turn=inputs['src_turn']) | |||
| outputs['mlm_probs'] = self._mlm_head(mlm_embed=mlm_embed) | |||
| if self.with_rdrop or self.with_contrastive: | |||
| enc_embed, dec_embed = self._encoder_decoder_network( | |||
| src_token=aug(inputs['src_token']), | |||
| src_mask=aug(inputs['src_mask']), | |||
| tgt_token=aug(inputs['tgt_token']), | |||
| tgt_mask=aug(inputs['tgt_mask']), | |||
| src_pos=aug(inputs['src_pos']), | |||
| src_type=aug(inputs['src_type']), | |||
| src_turn=aug(inputs['src_turn'])) | |||
| else: | |||
| enc_embed, dec_embed = self._encoder_decoder_network( | |||
| src_token=inputs['src_token'], | |||
| src_mask=inputs['src_mask'], | |||
| tgt_token=inputs['tgt_token'], | |||
| tgt_mask=inputs['tgt_mask'], | |||
| src_pos=inputs['src_pos'], | |||
| src_type=inputs['src_type'], | |||
| src_turn=inputs['src_turn']) | |||
| features = dec_embed[:, -1] | |||
| features = self.pooler(features) if self.with_pool else features | |||
| if self.example: | |||
| assert not self.with_rdrop | |||
| ex_enc_embed, ex_dec_embed = self._encoder_decoder_network( | |||
| src_token=inputs['example_src_token'], | |||
| src_mask=inputs['example_src_mask'], | |||
| tgt_token=inputs['example_tgt_token'], | |||
| tgt_mask=inputs['example_tgt_mask'], | |||
| src_pos=inputs['example_src_pos'], | |||
| src_type=inputs['example_src_type'], | |||
| src_turn=inputs['example_src_turn']) | |||
| ex_features = ex_dec_embed[:, -1] | |||
| ex_features = self.pooler( | |||
| ex_features) if self.with_pool else ex_features | |||
| probs = self.softmax(features.mm(ex_features.t())) | |||
| example_intent = inputs['example_intent'].unsqueeze(0) | |||
| intent_probs = torch.zeros(probs.size(0), self.num_intent) | |||
| intent_probs = intent_probs.cuda( | |||
| ) if self.use_gpu else intent_probs | |||
| intent_probs = intent_probs.scatter_add( | |||
| -1, example_intent.repeat(probs.size(0), 1), probs) | |||
| outputs['intent_probs'] = intent_probs | |||
| else: | |||
| intent_logits = self.intent_classifier(features) | |||
| outputs['intent_logits'] = intent_logits | |||
| if self.with_contrastive: | |||
| features = features if self.with_pool else self.pooler(features) | |||
| batch_size = features.size(0) // 2 | |||
| features = \ | |||
| torch.cat( | |||
| [features[:batch_size].unsqueeze(1), features[batch_size:].unsqueeze(1)], | |||
| dim=1 | |||
| ) | |||
| features = F.normalize(features, dim=-1, p=2) | |||
| outputs['features'] = features | |||
| return outputs | |||
| def _collect_metrics(self, inputs, outputs, with_label, data_file): | |||
| metrics = {} | |||
| batch_size = inputs['src_token'].size(0) | |||
| intent_label = torch.cat([inputs['intent_label'], inputs['intent_label']], dim=0) \ | |||
| if self.with_rdrop or self.with_contrastive else inputs['intent_label'] | |||
| if self.example: | |||
| intent_loss = self.loss_fct( | |||
| torch.log(outputs['intent_probs'] + 1e-12).view( | |||
| -1, self.num_intent), intent_label.type(torch.long)) | |||
| else: | |||
| intent_loss = self.loss_fct( | |||
| outputs['intent_logits'].view(-1, self.num_intent), | |||
| intent_label.type(torch.long)) | |||
| metrics['intent_loss'] = intent_loss | |||
| loss = intent_loss | |||
| if self.with_mlm: | |||
| mlm_num = torch.sum(torch.sum(inputs['mlm_mask'], dim=1)) | |||
| mlm = self.nll_loss( | |||
| torch.log(outputs['mlm_probs'] + 1e-12).permute(0, 2, 1), | |||
| inputs['mlm_label']) | |||
| mlm = torch.sum(mlm, dim=1) | |||
| token_mlm = torch.sum(mlm) / mlm_num | |||
| mlm = torch.mean(mlm) | |||
| metrics['mlm'] = mlm | |||
| metrics['token_mlm'] = token_mlm | |||
| metrics['mlm_num'] = mlm_num | |||
| loss = loss + (token_mlm | |||
| if self.token_loss else mlm) * self.mlm_ratio | |||
| else: | |||
| mlm, token_mlm, mlm_num = None, None, None | |||
| if self.with_rdrop: | |||
| kl = compute_kl_loss( | |||
| p=outputs['intent_logits'][:batch_size], | |||
| q=outputs['intent_logits'][batch_size:]) | |||
| metrics['kl'] = kl | |||
| loss = loss + kl * self.kl_ratio | |||
| else: | |||
| kl = None | |||
| if self.with_contrastive: | |||
| pass | |||
| con = None | |||
| else: | |||
| con = None | |||
| metrics['loss'] = loss | |||
| if self.gpu > 1: | |||
| return intent_loss, mlm, token_mlm, mlm_num, kl, con | |||
| else: | |||
| return metrics | |||
| def _infer(self, | |||
| inputs, | |||
| start_id=None, | |||
| eos_id=None, | |||
| max_gen_len=None, | |||
| prev_input=None): | |||
| """ Real inference process of model. """ | |||
| results = {} | |||
| enc_embed, dec_embed = self._encoder_decoder_network( | |||
| src_token=inputs['src_token'], | |||
| src_mask=inputs['src_mask'], | |||
| tgt_token=inputs['tgt_token'], | |||
| tgt_mask=inputs['tgt_mask'], | |||
| src_pos=inputs['src_pos'], | |||
| src_type=inputs['src_type'], | |||
| src_turn=inputs['src_turn']) | |||
| features = dec_embed[:, -1] | |||
| features = self.pooler(features) if self.with_pool else features | |||
| if self.example: | |||
| results['features'] = features | |||
| else: | |||
| intent_logits = self.intent_classifier(features) | |||
| intent_probs = self.softmax(intent_logits) | |||
| results['intent_probs'] = intent_probs | |||
| return results | |||
| IntentUnifiedTransformer.register('IntentUnifiedTransformer') | |||
| @@ -0,0 +1,101 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import torch.nn as nn | |||
| from .....utils.constant import ModelFile | |||
| class SpaceModelBase(nn.Module): | |||
| """ | |||
| Basic model wrapper for static graph and dygraph. | |||
| """ | |||
| _registry = dict() | |||
| @classmethod | |||
| def register(cls, name): | |||
| SpaceModelBase._registry[name] = cls | |||
| return | |||
| @staticmethod | |||
| def by_name(name): | |||
| return SpaceModelBase._registry[name] | |||
| @staticmethod | |||
| def create(model_dir, config, *args, **kwargs): | |||
| model_cls = SpaceModelBase.by_name(config.Model.model) | |||
| return model_cls(model_dir, config, *args, **kwargs) | |||
| def __init__(self, model_dir, config): | |||
| super(SpaceModelBase, self).__init__() | |||
| self.init_checkpoint = os.path.join(model_dir, | |||
| ModelFile.TORCH_MODEL_BIN_FILE) | |||
| self.abandon_label = config.Dataset.abandon_label | |||
| self.use_gpu = config.use_gpu | |||
| self.gpu = config.Trainer.gpu | |||
| return | |||
| def _create_parameters(self): | |||
| """ Create model's paramters. """ | |||
| raise NotImplementedError | |||
| def _forward(self, inputs, is_training, with_label): | |||
| """ NO LABEL: Real forward process of model in different mode(train/test). """ | |||
| raise NotImplementedError | |||
| def _collect_metrics(self, inputs, outputs, with_label, data_file): | |||
| """ NO LABEL: Calculate loss function by using inputs and outputs. """ | |||
| raise NotImplementedError | |||
| def _optimize(self, loss, optimizer, lr_scheduler): | |||
| """ Optimize loss function and update model. """ | |||
| raise NotImplementedError | |||
| def _infer(self, inputs, start_id, eos_id, max_gen_len, prev_input): | |||
| """ Real inference process of model. """ | |||
| raise NotImplementedError | |||
| def forward(self, | |||
| inputs, | |||
| is_training=False, | |||
| with_label=False, | |||
| data_file=None): | |||
| """ | |||
| Forward process: run the real forward pass, collect metrics, and optionally optimize. | |||
| @params : inputs : input data | |||
| @type : dict of numpy.ndarray/int/float/... | |||
| """ | |||
| if is_training: | |||
| self.train() | |||
| else: | |||
| self.eval() | |||
| with_label = False if self.abandon_label else with_label | |||
| outputs = self._forward(inputs, is_training, with_label=with_label) | |||
| metrics = self._collect_metrics( | |||
| inputs, outputs, with_label=with_label, data_file=data_file) | |||
| return metrics | |||
| def infer(self, | |||
| inputs, | |||
| start_id=None, | |||
| eos_id=None, | |||
| max_gen_len=None, | |||
| prev_input=None): | |||
| """ | |||
| Inference process. | |||
| @params : inputs : input data | |||
| @type : dict of numpy.ndarray/int/float/... | |||
| """ | |||
| self.eval() | |||
| results = self._infer( | |||
| inputs, | |||
| start_id=start_id, | |||
| eos_id=eos_id, | |||
| max_gen_len=max_gen_len, | |||
| prev_input=prev_input) | |||
| return results | |||
| @@ -0,0 +1,313 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from ..modules.embedder import Embedder | |||
| from ..modules.transformer_block import TransformerBlock | |||
| from .model_base import SpaceModelBase | |||
| class UnifiedTransformer(SpaceModelBase): | |||
| """ | |||
| Implement unified transformer. | |||
| """ | |||
| def __init__(self, model_dir, config, reader, generator, dtype='float32'): | |||
| super(UnifiedTransformer, self).__init__(model_dir, config) | |||
| self.reader = reader | |||
| self.generator = generator | |||
| self.policy = config.BPETextField.policy | |||
| self.generation = config.BPETextField.generation | |||
| self.num_token_embeddings = config.Model.num_token_embeddings | |||
| self.num_pos_embeddings = config.Model.num_pos_embeddings | |||
| self.num_type_embeddings = config.Model.num_type_embeddings | |||
| self.num_turn_embeddings = config.Model.num_turn_embeddings | |||
| self.temperature = config.Model.temperature | |||
| self.hidden_dim = config.Model.hidden_dim | |||
| self.num_heads = config.Model.num_heads | |||
| self.num_layers = config.Model.num_layers | |||
| self.padding_idx = config.Model.padding_idx | |||
| self.dropout = config.Model.dropout | |||
| self.embed_dropout = config.Model.embed_dropout | |||
| self.attn_dropout = config.Model.attn_dropout | |||
| self.ff_dropout = config.Model.ff_dropout | |||
| self.mlm_ratio = config.Model.mlm_ratio | |||
| self.mmd_ratio = config.Model.mmd_ratio | |||
| self.pos_trainable = config.Model.pos_trainable | |||
| self.label_smooth = config.Model.label_smooth | |||
| self.initializer_range = config.Model.initializer_range | |||
| self.gradient_accumulation_steps = config.Model.gradient_accumulation_steps | |||
| self.token_loss = config.Trainer.token_loss | |||
| self.learning_method = config.Dataset.learning_method | |||
| self.with_contrastive = config.Dataset.with_contrastive | |||
| self.with_query_bow = config.BPETextField.with_query_bow | |||
| self.with_resp_bow = config.BPETextField.with_resp_bow | |||
| self.with_pool = config.Model.with_pool | |||
| self.with_mlm = config.Dataset.with_mlm | |||
| self._dtype = dtype | |||
| self.embedder = Embedder( | |||
| self.hidden_dim, | |||
| self.num_token_embeddings, | |||
| self.num_pos_embeddings, | |||
| self.num_type_embeddings, | |||
| self.num_turn_embeddings, | |||
| padding_idx=self.padding_idx, | |||
| dropout=self.embed_dropout, | |||
| pos_trainable=self.pos_trainable) | |||
| self.embed_layer_norm = nn.LayerNorm( | |||
| normalized_shape=self.hidden_dim, | |||
| eps=1e-12, | |||
| elementwise_affine=True) | |||
| self.layers = nn.ModuleList([ | |||
| TransformerBlock(self.hidden_dim, self.num_heads, self.dropout, | |||
| self.attn_dropout, self.ff_dropout) | |||
| for _ in range(config.Model.num_layers) | |||
| ]) | |||
| if self.with_mlm: | |||
| self.mlm_transform = nn.Sequential( | |||
| nn.Linear(self.hidden_dim, self.hidden_dim), nn.GELU(), | |||
| nn.LayerNorm( | |||
| normalized_shape=self.hidden_dim, | |||
| eps=1e-12, | |||
| elementwise_affine=True)) | |||
| self.mlm_bias = nn.Parameter( | |||
| torch.zeros(self.num_token_embeddings)) | |||
| self.pooler = nn.Sequential( | |||
| nn.Linear(self.hidden_dim, self.hidden_dim), nn.Tanh()) | |||
| if self.with_query_bow or self.with_resp_bow: | |||
| self.bow_predictor = nn.Linear( | |||
| self.hidden_dim, self.num_token_embeddings, bias=False) | |||
| self.sigmoid = nn.Sigmoid() | |||
| self.softmax = nn.Softmax(dim=-1) | |||
| self.bce_loss = nn.BCELoss(reduction='none') | |||
| self.nll_loss = nn.NLLLoss( | |||
| ignore_index=self.padding_idx, reduction='none') | |||
| self._create_parameters() | |||
| self.max_grad_norm = config.Model.max_grad_norm | |||
| if self.max_grad_norm is not None: | |||
| self.grad_clip = self.max_grad_norm | |||
| else: | |||
| self.grad_clip = None | |||
| self.weight_decay = config.Model.weight_decay | |||
| if self.use_gpu: | |||
| self.cuda() | |||
| return | |||
| def _create_parameters(self): | |||
| """ Create model's paramters. """ | |||
| sequence_mask = np.tri( | |||
| self.num_pos_embeddings, | |||
| self.num_pos_embeddings, | |||
| dtype=self._dtype) | |||
| self.sequence_mask = torch.tensor(sequence_mask) | |||
| return | |||
| def _create_mask(self, | |||
| input_mask, | |||
| append_head=False, | |||
| auto_regressive=False): | |||
| """ | |||
| Create attention mask. | |||
| from sequence to matrix: [batch_size, max_seq_len] -> [batch_size, max_seq_len, max_seq_len] | |||
| @param : input_mask | |||
| @type : Variable(shape: [batch_size, max_seq_len]) | |||
| @param : auto_regressive | |||
| @type : bool | |||
| """ | |||
| seq_len = input_mask.shape[1] | |||
| input_mask = input_mask.float() | |||
| mask1 = input_mask.unsqueeze(-1).repeat(1, 1, seq_len) | |||
| mask2 = mask1.permute(0, 2, 1) | |||
| mask = mask1 * mask2 | |||
| if append_head: | |||
| mask = torch.cat([mask[:, :1, :], mask], dim=1) | |||
| mask = torch.cat([mask[:, :, :1], mask], dim=2) | |||
| seq_len += 1 | |||
| if auto_regressive: | |||
| seq_mask = self.sequence_mask[:seq_len, :seq_len] | |||
| seq_mask = seq_mask.to(mask.device) | |||
| mask = mask * seq_mask | |||
| mask = 1 - mask | |||
| return mask | |||
| def _join_mask(self, mask1, mask2): | |||
| """ | |||
| Merge source attention mask and target attention mask. | |||
| There are four parts: left upper (lu) / right upper (ru) / left below (lb) / right below (rb) | |||
| @param : mask1 : source attention mask | |||
| @type : Variable(shape: [batch_size, max_src_len, max_src_len]) | |||
| @param : mask2 : target attention mask | |||
| @type : Variable(shape: [batch_size, max_tgt_len, max_tgt_len]) | |||
| """ | |||
| batch_size = mask1.shape[0] | |||
| seq_len1 = mask1.shape[1] | |||
| seq_len2 = mask2.shape[1] | |||
| # seq_len = seq_len1 + seq_len2 | |||
| mask_lu = mask1 | |||
| mask_ru = torch.ones(batch_size, seq_len1, seq_len2) | |||
| if self.use_gpu: | |||
| mask_ru = mask_ru.cuda() | |||
| mask3 = mask2[:, :, :1].repeat(1, 1, seq_len1) | |||
| mask4 = mask1[:, :1].repeat(1, seq_len2, 1) | |||
| mask_lb = mask3 + mask4 - mask3 * mask4 | |||
| mask_rb = mask2 | |||
| mask_u = torch.cat([mask_lu, mask_ru], dim=2) | |||
| mask_b = torch.cat([mask_lb, mask_rb], dim=2) | |||
| mask = torch.cat([mask_u, mask_b], dim=1) | |||
| return mask | |||
| def _mlm_head(self, mlm_embed): | |||
| mlm_embed = self.mlm_transform(mlm_embed) | |||
| mlm_logits = torch.matmul( | |||
| mlm_embed, self.embedder.token_embedding.weight.T) + self.mlm_bias | |||
| mlm_probs = self.softmax(mlm_logits) | |||
| return mlm_probs | |||
| def _dec_head(self, dec_embed): | |||
| dec_logits = torch.matmul(dec_embed, | |||
| self.embedder.token_embedding.weight.T) | |||
| dec_probs = self.softmax(dec_logits) | |||
| return dec_probs | |||
| def _refactor_feature(self, features): | |||
| features = self.pooler(features) if self.with_pool else features | |||
| batch_size = features.size(0) // 2 | |||
| features = \ | |||
| torch.cat( | |||
| [features[:batch_size].unsqueeze(1), features[batch_size:].unsqueeze(1)], | |||
| dim=1 | |||
| ) | |||
| features = F.normalize(features, dim=-1, p=2) | |||
| return features | |||
| def _encoder_network(self, | |||
| input_token, | |||
| input_mask, | |||
| input_pos=None, | |||
| input_type=None, | |||
| input_turn=None): | |||
| embed = self.embedder(input_token, input_pos, input_type, input_turn) | |||
| embed = self.embed_layer_norm(embed) | |||
| mask = self._create_mask(input_mask, auto_regressive=False) | |||
| for layer in self.layers: | |||
| embed = layer(embed, mask, None) | |||
| return embed | |||
| def _encoder_decoder_network(self, | |||
| src_token, | |||
| src_mask, | |||
| tgt_token, | |||
| tgt_mask, | |||
| src_pos=None, | |||
| src_type=None, | |||
| src_turn=None, | |||
| tgt_pos=None, | |||
| tgt_type=None, | |||
| tgt_turn=None): | |||
| src_embed = self.embedder(src_token, src_pos, src_type, src_turn) | |||
| tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn) | |||
| embed = torch.cat([src_embed, tgt_embed], dim=1) | |||
| embed = self.embed_layer_norm(embed) | |||
| enc_mask = self._create_mask(src_mask, auto_regressive=False) | |||
| dec_mask = self._create_mask(tgt_mask, auto_regressive=True) | |||
| mask = self._join_mask(enc_mask, dec_mask) | |||
| for layer in self.layers: | |||
| embed = layer(embed, mask, None) | |||
| tgt_len = tgt_token.shape[1] | |||
| enc_embed = embed[:, :-tgt_len] | |||
| dec_embed = embed[:, -tgt_len:] | |||
| return enc_embed, dec_embed | |||
| def _encoder_prompt_decoder_network(self, | |||
| src_token, | |||
| src_mask, | |||
| tgt_token, | |||
| tgt_mask, | |||
| prompt_token, | |||
| prompt_mask, | |||
| src_pos=None, | |||
| src_type=None, | |||
| src_turn=None, | |||
| tgt_pos=None, | |||
| tgt_type=None, | |||
| tgt_turn=None, | |||
| prompt_pos=None, | |||
| prompt_type=None, | |||
| prompt_turn=None): | |||
| src_embed = self.embedder(src_token, src_pos, src_type, src_turn) | |||
| tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn) | |||
| prompt_embed = self.embedder(prompt_token, prompt_pos, prompt_type, | |||
| prompt_turn) | |||
| embed = torch.cat([src_embed, prompt_embed, tgt_embed], dim=1) | |||
| embed = self.embed_layer_norm(embed) | |||
| enc_mask = self._create_mask(src_mask, auto_regressive=False) | |||
| dec_mask = self._create_mask( | |||
| torch.cat([prompt_mask, tgt_mask], dim=1), auto_regressive=True) | |||
| mask = self._join_mask(enc_mask, dec_mask) | |||
| for layer in self.layers: | |||
| embed = layer(embed, mask, None) | |||
| src_len = src_token.shape[1] | |||
| tgt_len = tgt_token.shape[1] | |||
| enc_embed = embed[:, :src_len] | |||
| dec_embed = embed[:, -tgt_len:] | |||
| prompt_embed = embed[:, src_len:-tgt_len] | |||
| return enc_embed, dec_embed, prompt_embed | |||
| def _optimize(self, loss, optimizer=None, lr_scheduler=None): | |||
| """ Optimize loss function and update model. """ | |||
| assert optimizer is not None | |||
| optimizer.zero_grad() | |||
| loss.backward() | |||
| if self.grad_clip is not None and self.grad_clip > 0: | |||
| torch.nn.utils.clip_grad_norm_( | |||
| parameters=self.parameters(), max_norm=self.grad_clip) | |||
| optimizer.step() | |||
| if lr_scheduler is not None: | |||
| lr_scheduler.step() | |||
| return | |||
| def _infer(self, | |||
| inputs, | |||
| start_id=None, | |||
| eos_id=None, | |||
| max_gen_len=None, | |||
| prev_input=None): | |||
| """ Real inference process of model. """ | |||
| results = {} | |||
| return results | |||
| UnifiedTransformer.register('UnifiedTransformer') | |||
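Reviewer note: the `_create_mask` / `_join_mask` pair above implements the UniLM-style joint mask (bidirectional over the source, causal over the target, and the source never attending to the target). Below is a minimal standalone sketch that mirrors that logic in plain torch so the expected layout can be eyeballed; the toy shapes and the helper name `create_mask` are mine, not part of this PR.

import torch

def create_mask(input_mask, auto_regressive=False):
    # input_mask: [batch, seq_len], 1 for real tokens, 0 for padding
    seq_len = input_mask.shape[1]
    m = input_mask.float().unsqueeze(-1).repeat(1, 1, seq_len)
    mask = m * m.permute(0, 2, 1)
    if auto_regressive:
        mask = mask * torch.tril(torch.ones(seq_len, seq_len))
    return 1 - mask  # 1 marks positions that must NOT be attended to

src = torch.tensor([[1, 1, 1, 0]])              # 3 source tokens + 1 pad
tgt = torch.tensor([[1, 1, 1]])                 # 3 target tokens
enc = create_mask(src)                          # bidirectional over source
dec = create_mask(tgt, auto_regressive=True)    # causal over target

b, s1, s2 = enc.shape[0], enc.shape[1], dec.shape[1]
ru = torch.ones(b, s1, s2)                      # source never sees target
lb3 = dec[:, :, :1].repeat(1, 1, s1)            # blocked target rows
lb4 = enc[:, :1].repeat(1, s2, 1)               # padded source columns
lb = lb3 + lb4 - lb3 * lb4                      # logical OR, as in _join_mask
joint = torch.cat([torch.cat([enc, ru], dim=2),
                   torch.cat([lb, dec], dim=2)], dim=1)
print(joint[0])                                 # 0 = attend, 1 = blocked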
| @@ -0,0 +1,65 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import torch | |||
| import torch.nn as nn | |||
| class Embedder(nn.Module): | |||
| """ | |||
| Composite embedding layer. | |||
| """ | |||
| def __init__(self, | |||
| hidden_dim, | |||
| num_token_embeddings, | |||
| num_pos_embeddings, | |||
| num_type_embeddings, | |||
| num_turn_embeddings, | |||
| padding_idx=None, | |||
| dropout=0.1, | |||
| pos_trainable=False): | |||
| super(Embedder, self).__init__() | |||
| self.token_embedding = nn.Embedding(num_token_embeddings, hidden_dim) | |||
| self.pos_embedding = nn.Embedding(num_pos_embeddings, hidden_dim) | |||
| self.pos_embedding.weight.requires_grad = pos_trainable | |||
| self.type_embedding = nn.Embedding(num_type_embeddings, hidden_dim) | |||
| self.turn_embedding = nn.Embedding(num_turn_embeddings, hidden_dim) | |||
| self.dropout_layer = nn.Dropout(p=dropout) | |||
| # follow the default xavier_uniform initializer of the paddle version; | |||
| # otherwise dec_probs computation breaks under the weight-tying setting, | |||
| # because the default normal initializer of nn.Embedding in pytorch samples larger values | |||
| nn.init.xavier_uniform_(self.token_embedding.weight) | |||
| nn.init.xavier_uniform_(self.pos_embedding.weight) | |||
| nn.init.xavier_uniform_(self.type_embedding.weight) | |||
| nn.init.xavier_uniform_(self.turn_embedding.weight) | |||
| return | |||
| def forward(self, token_inp, pos_inp=None, type_inp=None, turn_inp=None): | |||
| embed = self.token_embedding(token_inp) | |||
| if pos_inp is not None: | |||
| embed += self.pos_embedding(pos_inp) | |||
| if type_inp is not None: | |||
| embed += self.type_embedding(type_inp) | |||
| if turn_inp is not None: | |||
| embed += self.turn_embedding(turn_inp) | |||
| embed = self.dropout_layer(embed) | |||
| return embed | |||
| def main(): | |||
| import numpy as np | |||
| model = Embedder(10, 20, 20, 20, 20) | |||
| token_inp = torch.tensor( | |||
| np.random.randint(0, 19, [10, 10]).astype('int64')) | |||
| pos_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) | |||
| type_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) | |||
| turn_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) | |||
| out = model(token_inp, pos_inp, type_inp, turn_inp) | |||
| print(out) | |||
| if __name__ == '__main__': | |||
| main() | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import torch | |||
| import torch.nn as nn | |||
| class FeedForward(nn.Module): | |||
| """ | |||
| Positional feed forward layer. | |||
| """ | |||
| def __init__(self, hidden_dim, inner_dim, dropout): | |||
| super(FeedForward, self).__init__() | |||
| self.hidden_dim = hidden_dim | |||
| self.inner_dim = inner_dim | |||
| self.linear_hidden = nn.Sequential( | |||
| nn.Linear(hidden_dim, inner_dim), nn.GELU()) | |||
| self.linear_out = nn.Linear(inner_dim, hidden_dim) | |||
| self.dropout_layer = nn.Dropout(p=dropout) | |||
| return | |||
| def forward(self, x): | |||
| out = self.linear_hidden(x) | |||
| out = self.dropout_layer(out) | |||
| out = self.linear_out(out) | |||
| return out | |||
| def main(): | |||
| import numpy as np | |||
| model = FeedForward(10, 20, 0.5) | |||
| inp = np.random.rand(2, 3, 10).astype('float32') | |||
| inp = torch.tensor(inp) | |||
| out = model(inp) | |||
| print(out) | |||
| if __name__ == '__main__': | |||
| main() | |||
| @@ -0,0 +1,62 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn.functional as F | |||
| def unsqueeze(input, dims): | |||
| """ Implement multi-dimension unsqueeze function. """ | |||
| if isinstance(dims, (list, tuple)): | |||
| dims = [ | |||
| dim if dim >= 0 else dim + len(input.shape) + 1 for dim in dims | |||
| ] | |||
| dims = sorted(dims, reverse=True) | |||
| shape = list(input.shape) | |||
| for dim in dims: | |||
| shape.insert(dim, 1) | |||
| return torch.reshape(input, shape) | |||
| elif isinstance(dims, int): | |||
| return input.unsqueeze(dims) | |||
| else: | |||
| raise ValueError('type(dims) must be one of (list, tuple, int)!') | |||
| def gumbel_softmax(input, tau=1, eps=1e-10): | |||
| """ Basic implement of gumbel_softmax. """ | |||
| U = torch.tensor(np.random.rand(*input.shape)) | |||
| gumbel = 0.0 - torch.log(eps - torch.log(U + eps)) | |||
| y = input + gumbel | |||
| return F.softmax(y / tau, dim=-1) | |||
| def equal(x, y, dtype=None): | |||
| """ Implement equal in dygraph mode. (paddle) """ | |||
| if dtype is None: | |||
| dtype = 'float32' | |||
| if isinstance(x, torch.Tensor): | |||
| x = x.numpy() | |||
| if isinstance(y, torch.Tensor): | |||
| y = y.numpy() | |||
| out = np.equal(x, y).astype(dtype) | |||
| return torch.tensor(out) | |||
| def not_equal(x, y, dtype=None): | |||
| """ Implement not_equal in dygraph mode. (paddle) """ | |||
| return 1 - equal(x, y, dtype) | |||
| if __name__ == '__main__': | |||
| a = torch.tensor([[1, 1], [3, 4]]) | |||
| b = torch.tensor([[1, 1], [3, 4]]) | |||
| c = torch.equal(a, a) | |||
| c1 = equal(a, 3) | |||
| d = 1 - torch.not_equal(a, 3).float() | |||
| print(c) | |||
| print(c1) | |||
| print(d) | |||
| e = F.gumbel_softmax(a.float()) | |||
| f = a.unsqueeze(0) | |||
| g = unsqueeze(a, dims=[0, 0, 1]) | |||
| print(g, g.shape) | |||
| @@ -0,0 +1,105 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import torch | |||
| import torch.nn as nn | |||
| class MultiheadAttention(nn.Module): | |||
| """ | |||
| Multi head attention layer. | |||
| """ | |||
| def __init__(self, hidden_dim, num_heads, dropout): | |||
| assert hidden_dim % num_heads == 0 | |||
| super(MultiheadAttention, self).__init__() | |||
| self.hidden_dim = hidden_dim | |||
| self.num_heads = num_heads | |||
| self.head_dim = hidden_dim // num_heads | |||
| self.scale = self.head_dim**-0.5 | |||
| self.linear_qkv = nn.Linear(hidden_dim, hidden_dim * 3) | |||
| self.linear_out = nn.Linear(hidden_dim, hidden_dim) | |||
| self.dropout_layer = nn.Dropout(p=dropout) | |||
| self.softmax = nn.Softmax(dim=-1) | |||
| return | |||
| def _split_heads(self, x, is_key=False): | |||
| x = x.reshape(x.size(0), x.size(1), self.num_heads, self.head_dim) | |||
| x = x.permute(0, 2, 3, 1) if is_key else x.permute(0, 2, 1, 3) | |||
| return x | |||
| def _merge_heads(self, x): | |||
| x = x.permute(0, 2, 1, 3) | |||
| x = x.reshape(x.size(0), x.size(1), self.hidden_dim) | |||
| return x | |||
| def _attn(self, query, key, value, mask): | |||
| # shape: [batch_size, num_head, seq_len, seq_len] | |||
| scores = torch.matmul(query, key) | |||
| scores = scores * self.scale | |||
| if mask is not None: | |||
| mask = mask.unsqueeze(1) | |||
| mask = mask.repeat(1, self.num_heads, 1, 1) | |||
| scores.masked_fill_( | |||
| mask.bool(), | |||
| float('-inf')) # scores = (1 - mask) * scores + mask * (-1e10) | |||
| attn = self.softmax(scores) | |||
| attn = self.dropout_layer(attn) | |||
| if mask is not None: | |||
| ''' | |||
| mask: [batch size, num_heads, seq_len, seq_len] | |||
| >>> F.softmax([-1e10, -100, -100]) | |||
| >>> [0.00, 0.50, 0.50] | |||
| >>> F.softmax([-1e10, -1e10, -1e10]) | |||
| >>> [0.33, 0.33, 0.33] | |||
| ==> [0.00, 0.00, 0.00] | |||
| ''' | |||
| attn.masked_fill_(mask.bool(), 0.) # attn = (1 - mask) * attn | |||
| out = torch.matmul(attn, value) | |||
| return out | |||
| def forward(self, inp, mask=None, cache=None): | |||
| """ Forward process of self attention. """ | |||
| # shape: [batch_size, seq_len, 3 * hidden_dim] | |||
| qkv = self.linear_qkv(inp) | |||
| query, key, value = torch.split(qkv, self.hidden_dim, dim=2) | |||
| # shape: [batch_size, num_head, seq_len, head_dim] | |||
| query = self._split_heads(query) | |||
| # shape: [batch_size, num_head, head_dim, seq_len] | |||
| key = self._split_heads(key, is_key=True) | |||
| # shape: [batch_size, num_head, seq_len, head_dim] | |||
| value = self._split_heads(value) | |||
| if cache is not None: | |||
| if 'key' in cache and 'value' in cache: | |||
| key = torch.cat([cache['key'], key], dim=3) | |||
| value = torch.cat([cache['value'], value], dim=2) | |||
| cache['key'] = key | |||
| cache['value'] = value | |||
| out = self._attn(query, key, value, mask) | |||
| out = self._merge_heads(out) | |||
| out = self.linear_out(out) | |||
| return out | |||
| def main(): | |||
| import numpy as np | |||
| model = MultiheadAttention(10, 2, 0.5) | |||
| inp = np.random.rand(2, 3, 10).astype('float32') | |||
| inp = torch.tensor(inp) | |||
| mask = (np.random.rand(2, 3, 3) > 0.5).astype('float32') | |||
| mask = torch.tensor(mask) | |||
| out = model(inp, mask=mask, cache=None) | |||
| print(out) | |||
| if __name__ == '__main__': | |||
| main() | |||
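One thing worth calling out for reviewers: `cache` keeps keys transposed as [batch, head, head_dim, seq] and values as [batch, head, seq, head_dim], which is why the two `torch.cat` calls use dim=3 and dim=2 respectively. A small usage sketch of incremental (one token at a time) decoding with the module above; the toy shapes are assumptions and dropout is set to 0 so the run is deterministic.

import torch

attn = MultiheadAttention(hidden_dim=10, num_heads=2, dropout=0.0)
attn.eval()

x = torch.rand(2, 5, 10)          # [batch, seq_len, hidden_dim]
cache = {}
outs = []
for t in range(x.size(1)):
    step = x[:, t:t + 1, :]       # feed one position, reuse cached keys/values
    outs.append(attn(step, mask=None, cache=cache))
out = torch.cat(outs, dim=1)      # [batch, seq_len, hidden_dim]
print(out.shape)
print(cache['key'].shape, cache['value'].shape)  # seq dims have grown to 5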
| @@ -0,0 +1,70 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import torch | |||
| import torch.nn as nn | |||
| from .feedforward import FeedForward | |||
| from .multihead_attention import MultiheadAttention | |||
| class TransformerBlock(nn.Module): | |||
| """ | |||
| Transformer block module. | |||
| """ | |||
| def __init__(self, hidden_dim, num_heads, dropout, attn_dropout, | |||
| ff_dropout): | |||
| super(TransformerBlock, self).__init__() | |||
| self.attn = MultiheadAttention( | |||
| hidden_dim=hidden_dim, num_heads=num_heads, dropout=attn_dropout) | |||
| self.attn_norm = nn.LayerNorm( | |||
| normalized_shape=hidden_dim, eps=1e-12, elementwise_affine=True) | |||
| self.ff = FeedForward( | |||
| hidden_dim=hidden_dim, | |||
| inner_dim=4 * hidden_dim, | |||
| dropout=ff_dropout) | |||
| self.ff_norm = nn.LayerNorm( | |||
| normalized_shape=hidden_dim, eps=1e-12, elementwise_affine=True) | |||
| self.dropout_layer = nn.Dropout(p=dropout) | |||
| return | |||
| def forward(self, inp, mask=None, cache=None): | |||
| """ | |||
| Forward process on one transformer layer. | |||
| @param : inp | |||
| @type : Variable(shape: [batch_size, seq_len, hidden_size]) | |||
| @param : mask | |||
| @param : cache | |||
| """ | |||
| attn_out = self.attn(inp, mask, cache) | |||
| attn_out = self.dropout_layer(attn_out) | |||
| attn_out = self.attn_norm(attn_out + inp) | |||
| ff_out = self.ff(attn_out) | |||
| ff_out = self.dropout_layer(ff_out) | |||
| ff_out = self.ff_norm(ff_out + attn_out) | |||
| return ff_out | |||
| def main(): | |||
| import numpy as np | |||
| model = TransformerBlock(10, 2, 0.5, 0.5, 0.5) | |||
| inp = np.random.rand(2, 3, 10).astype('float32') | |||
| inp = torch.tensor(inp) | |||
| mask = (np.random.rand(2, 3, 3) > 0.5).astype('float32') | |||
| mask = torch.tensor(mask) | |||
| out = model(inp, mask=mask, cache=None) | |||
| print(out) | |||
| if __name__ == '__main__': | |||
| main() | |||
| @@ -21,6 +21,12 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| Tasks.sentence_similarity: | |||
| (Pipelines.sentence_similarity, | |||
| 'damo/nlp_structbert_sentence-similarity_chinese-base'), | |||
| Tasks.nli: (Pipelines.nli, 'damo/nlp_structbert_nli_chinese-base'), | |||
| Tasks.sentiment_classification: | |||
| (Pipelines.sentiment_classification, | |||
| 'damo/nlp_structbert_sentiment-classification_chinese-base'), | |||
| Tasks.text_classification: ('bert-sentiment-analysis', | |||
| 'damo/bert-base-sst2'), | |||
| Tasks.image_matting: (Pipelines.image_matting, | |||
| 'damo/cv_unet_image-matting'), | |||
| Tasks.text_classification: (Pipelines.sentiment_analysis, | |||
| @@ -30,6 +36,11 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| Tasks.zero_shot_classification: | |||
| (Pipelines.zero_shot_classification, | |||
| 'damo/nlp_structbert_zero-shot-classification_chinese-base'), | |||
| Tasks.dialog_intent_prediction: | |||
| (Pipelines.dialog_intent_prediction, | |||
| 'damo/nlp_space_dialog-intent-prediction'), | |||
| Tasks.dialog_modeling: (Pipelines.dialog_modeling, | |||
| 'damo/nlp_space_dialog-modeling'), | |||
| Tasks.image_captioning: (Pipelines.image_caption, | |||
| 'damo/ofa_image-caption_coco_large_en'), | |||
| Tasks.image_generation: | |||
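With the two space models and the structbert nli / sentiment-classification defaults registered here, the new tasks can be smoke-tested without naming a model. A hedged sketch: the factory name `pipeline` and its `task` keyword follow the existing modelscope API, and the Chinese example sentence is purely illustrative.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# the model id is looked up in DEFAULT_MODEL_FOR_PIPELINE when not passed explicitly
sentiment = pipeline(task=Tasks.sentiment_classification)
print(sentiment('这家店的服务态度特别好'))  # -> {'scores': [...], 'labels': [...]}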
| @@ -1,6 +1,10 @@ | |||
| try: | |||
| from .dialog_intent_prediction_pipeline import * # noqa F403 | |||
| from .dialog_modeling_pipeline import * # noqa F403 | |||
| from .fill_mask_pipeline import * # noqa F403 | |||
| from .nli_pipeline import * # noqa F403 | |||
| from .sentence_similarity_pipeline import * # noqa F403 | |||
| from .sentiment_classification_pipeline import * # noqa F403 | |||
| from .sequence_classification_pipeline import * # noqa F403 | |||
| from .text_generation_pipeline import * # noqa F403 | |||
| from .word_segmentation_pipeline import * # noqa F403 | |||
| @@ -0,0 +1,53 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict | |||
| from ...metainfo import Pipelines | |||
| from ...models.nlp import SpaceForDialogIntent | |||
| from ...preprocessors import DialogIntentPredictionPreprocessor | |||
| from ...utils.constant import Tasks | |||
| from ..base import Pipeline | |||
| from ..builder import PIPELINES | |||
| from ..outputs import OutputKeys | |||
| __all__ = ['DialogIntentPredictionPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.dialog_intent_prediction, | |||
| module_name=Pipelines.dialog_intent_prediction) | |||
| class DialogIntentPredictionPipeline(Pipeline): | |||
| def __init__(self, model: SpaceForDialogIntent, | |||
| preprocessor: DialogIntentPredictionPreprocessor, **kwargs): | |||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||
| Args: | |||
| model (SequenceClassificationModel): a model instance | |||
| preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | |||
| """ | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| self.model = model | |||
| self.categories = preprocessor.categories | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Any]): the forward outputs, containing the intent distribution under key 'pred' | |||
| Returns: | |||
| Dict[str, str]: the prediction results | |||
| """ | |||
| import numpy as np | |||
| pred = inputs['pred'] | |||
| pos = np.where(pred == np.max(pred)) | |||
| result = { | |||
| OutputKeys.PREDICTION: pred, | |||
| OutputKeys.LABEL_POS: pos[0], | |||
| OutputKeys.LABEL: self.categories[pos[0][0]] | |||
| } | |||
| return result | |||
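A hedged end-to-end sketch of how this pipeline is meant to be wired up; the model id comes from the default-model table in this PR, the example utterance is illustrative, and the factory keyword names are assumptions.

from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.preprocessors import DialogIntentPredictionPreprocessor
from modelscope.utils.constant import Tasks

model = Model.from_pretrained('damo/nlp_space_dialog-intent-prediction')
preprocessor = DialogIntentPredictionPreprocessor(model.model_dir)
pipe = pipeline(
    Tasks.dialog_intent_prediction, model=model, preprocessor=preprocessor)
result = pipe('I still have not received my new card, I ordered over a week ago.')
print(result['label'], result['label_pos'])  # e.g. lost_or_stolen_card [11]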
| @@ -0,0 +1,49 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict, Optional | |||
| from ...metainfo import Pipelines | |||
| from ...models.nlp import SpaceForDialogModeling | |||
| from ...preprocessors import DialogModelingPreprocessor | |||
| from ...utils.constant import Tasks | |||
| from ..base import Pipeline, Tensor | |||
| from ..builder import PIPELINES | |||
| from ..outputs import OutputKeys | |||
| __all__ = ['DialogModelingPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.dialog_modeling, module_name=Pipelines.dialog_modeling) | |||
| class DialogModelingPipeline(Pipeline): | |||
| def __init__(self, model: SpaceForDialogModeling, | |||
| preprocessor: DialogModelingPreprocessor, **kwargs): | |||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||
| Args: | |||
| model (SequenceClassificationModel): a model instance | |||
| preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | |||
| """ | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| self.model = model | |||
| self.preprocessor = preprocessor | |||
| def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Tensor]): the forward outputs, containing the generated response ids under key 'resp' | |||
| Returns: | |||
| Dict[str, str]: the prediction results | |||
| """ | |||
| sys_rsp = self.preprocessor.text_field.tokenizer.convert_ids_to_tokens( | |||
| inputs['resp']) | |||
| assert len(sys_rsp) > 2 | |||
| sys_rsp = sys_rsp[1:len(sys_rsp) - 1] | |||
| inputs[OutputKeys.RESPONSE] = sys_rsp | |||
| return inputs | |||
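A hedged usage sketch; the input dict layout is an assumption based only on DialogModelingPreprocessor reading data['user_input'] (whatever extra dialog-state keys a full multi-turn loop has to carry are not shown), and the model id is the default added in this PR.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(Tasks.dialog_modeling, model='damo/nlp_space_dialog-modeling')
result = pipe({'user_input': 'i need a cheap restaurant in the centre of town'})
print(' '.join(result['response']))  # detokenized system reply, see OutputKeys.RESPONSE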
| @@ -1,5 +1,7 @@ | |||
| import os | |||
| from typing import Dict, Optional, Union | |||
| from typing import Any, Dict, Optional, Union | |||
| import torch | |||
| from ...metainfo import Pipelines | |||
| from ...models import Model | |||
| @@ -21,6 +23,7 @@ class FillMaskPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[MaskedLanguageModelBase, str], | |||
| preprocessor: Optional[FillMaskPreprocessor] = None, | |||
| first_sequence='sentence', | |||
| **kwargs): | |||
| """use `model` and `preprocessor` to create a nlp fill mask pipeline for prediction | |||
| @@ -30,12 +33,16 @@ class FillMaskPipeline(Pipeline): | |||
| """ | |||
| fill_mask_model = model if isinstance( | |||
| model, MaskedLanguageModelBase) else Model.from_pretrained(model) | |||
| if preprocessor is None: | |||
| preprocessor = FillMaskPreprocessor( | |||
| fill_mask_model.model_dir, | |||
| first_sequence='sentence', | |||
| first_sequence=first_sequence, | |||
| second_sequence=None) | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| fill_mask_model.eval() | |||
| super().__init__( | |||
| model=fill_mask_model, preprocessor=preprocessor, **kwargs) | |||
| self.preprocessor = preprocessor | |||
| self.config = Config.from_file( | |||
| os.path.join(fill_mask_model.model_dir, ModelFile.CONFIGURATION)) | |||
| @@ -63,6 +70,11 @@ class FillMaskPipeline(Pipeline): | |||
| } | |||
| } | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return super().forward(inputs, **forward_params) | |||
| def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| """process the prediction results | |||
| @@ -0,0 +1,73 @@ | |||
| import uuid | |||
| from typing import Any, Dict, Union | |||
| import numpy as np | |||
| import torch | |||
| from ...metainfo import Pipelines | |||
| from ...models import Model | |||
| from ...models.nlp import SbertForNLI | |||
| from ...preprocessors import NLIPreprocessor | |||
| from ...utils.constant import Tasks | |||
| from ..base import Pipeline | |||
| from ..builder import PIPELINES | |||
| from ..outputs import OutputKeys | |||
| __all__ = ['NLIPipeline'] | |||
| @PIPELINES.register_module(Tasks.nli, module_name=Pipelines.nli) | |||
| class NLIPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[SbertForNLI, str], | |||
| preprocessor: NLIPreprocessor = None, | |||
| first_sequence='first_sequence', | |||
| second_sequence='second_sequence', | |||
| **kwargs): | |||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||
| Args: | |||
| model (SbertForNLI): a model instance | |||
| preprocessor (NLIPreprocessor): a preprocessor instance | |||
| """ | |||
| assert isinstance(model, str) or isinstance(model, SbertForNLI), \ | |||
| 'model must be a single str or SbertForNLI' | |||
| model = model if isinstance( | |||
| model, SbertForNLI) else Model.from_pretrained(model) | |||
| if preprocessor is None: | |||
| preprocessor = NLIPreprocessor( | |||
| model.model_dir, | |||
| first_sequence=first_sequence, | |||
| second_sequence=second_sequence) | |||
| model.eval() | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| assert len(model.id2label) > 0 | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return super().forward(inputs, **forward_params) | |||
| def postprocess(self, | |||
| inputs: Dict[str, Any], | |||
| topk: int = 5) -> Dict[str, str]: | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Any]): the forward outputs, containing key 'probabilities' | |||
| Returns: | |||
| Dict[str, str]: the prediction results | |||
| """ | |||
| probs = inputs['probabilities'][0] | |||
| num_classes = probs.shape[0] | |||
| topk = min(topk, num_classes) | |||
| top_indices = np.argpartition(probs, -topk)[-topk:] | |||
| cls_ids = top_indices[np.argsort(probs[top_indices])] | |||
| probs = probs[cls_ids].tolist() | |||
| cls_names = [self.model.id2label[cid] for cid in cls_ids] | |||
| return {OutputKeys.SCORES: probs, OutputKeys.LABELS: cls_names} | |||
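A short usage sketch for reviewers; the (premise, hypothesis) tuple follows NLIPreprocessor's expected input, the sentences are illustrative, and the model id is the default registered in this PR.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

nli = pipeline(Tasks.nli, model='damo/nlp_structbert_nli_chinese-base')
result = nli(('四川商务职业学院位于四川省成都市', '四川商务职业学院在四川'))
print(result)  # {'scores': [...], 'labels': [...]}, scores ascending per the argsort above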
| @@ -1,12 +1,13 @@ | |||
| from typing import Any, Dict, Union | |||
| import numpy as np | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.nlp import SbertForSentenceSimilarity | |||
| from modelscope.preprocessors import SequenceClassificationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from ...metainfo import Pipelines | |||
| from ...models import Model | |||
| from ...models.nlp import SbertForSentenceSimilarity | |||
| from ...preprocessors import SequenceClassificationPreprocessor | |||
| from ...utils.constant import Tasks | |||
| from ..base import Input, Pipeline | |||
| from ..builder import PIPELINES | |||
| from ..outputs import OutputKeys | |||
| @@ -19,8 +20,10 @@ __all__ = ['SentenceSimilarityPipeline'] | |||
| class SentenceSimilarityPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[SbertForSentenceSimilarity, str], | |||
| model: Union[Model, str], | |||
| preprocessor: SequenceClassificationPreprocessor = None, | |||
| first_sequence='first_sequence', | |||
| second_sequence='second_sequence', | |||
| **kwargs): | |||
| """use `model` and `preprocessor` to create a nlp sentence similarity pipeline for prediction | |||
| @@ -36,14 +39,21 @@ class SentenceSimilarityPipeline(Pipeline): | |||
| if preprocessor is None: | |||
| preprocessor = SequenceClassificationPreprocessor( | |||
| sc_model.model_dir, | |||
| first_sequence='first_sequence', | |||
| second_sequence='second_sequence') | |||
| first_sequence=first_sequence, | |||
| second_sequence=second_sequence) | |||
| sc_model.eval() | |||
| super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) | |||
| assert hasattr(self.model, 'id2label'), \ | |||
| 'id2label map should be initialized in the init function.' | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return super().forward(inputs, **forward_params) | |||
| def postprocess(self, inputs: Dict[str, Any], | |||
| **postprocess_params) -> Dict[str, str]: | |||
| """process the prediction results | |||
| Args: | |||
| @@ -0,0 +1,78 @@ | |||
| import os | |||
| import uuid | |||
| from typing import Any, Dict, Union | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| from ...metainfo import Pipelines | |||
| from ...models import Model | |||
| from ...models.nlp import SbertForSentimentClassification | |||
| from ...preprocessors import SentimentClassificationPreprocessor | |||
| from ...utils.constant import Tasks | |||
| from ..base import Input, Pipeline | |||
| from ..builder import PIPELINES | |||
| from ..outputs import OutputKeys | |||
| __all__ = ['SentimentClassificationPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.sentiment_classification, | |||
| module_name=Pipelines.sentiment_classification) | |||
| class SentimentClassificationPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[SbertForSentimentClassification, str], | |||
| preprocessor: SentimentClassificationPreprocessor = None, | |||
| first_sequence='first_sequence', | |||
| second_sequence='second_sequence', | |||
| **kwargs): | |||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||
| Args: | |||
| model (SbertForSentimentClassification): a model instance | |||
| preprocessor (SentimentClassificationPreprocessor): a preprocessor instance | |||
| """ | |||
| assert isinstance(model, str) or isinstance(model, SbertForSentimentClassification), \ | |||
| 'model must be a single str or SbertForSentimentClassification' | |||
| model = model if isinstance( | |||
| model, | |||
| SbertForSentimentClassification) else Model.from_pretrained(model) | |||
| if preprocessor is None: | |||
| preprocessor = SentimentClassificationPreprocessor( | |||
| model.model_dir, | |||
| first_sequence=first_sequence, | |||
| second_sequence=second_sequence) | |||
| model.eval() | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| assert len(model.id2label) > 0 | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return super().forward(inputs, **forward_params) | |||
| def postprocess(self, | |||
| inputs: Dict[str, Any], | |||
| topk: int = 5) -> Dict[str, str]: | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Any]): the forward outputs, containing key 'probabilities' | |||
| Returns: | |||
| Dict[str, str]: the prediction results | |||
| """ | |||
| probs = inputs['probabilities'][0] | |||
| num_classes = probs.shape[0] | |||
| topk = min(topk, num_classes) | |||
| top_indices = np.argpartition(probs, -topk)[-topk:] | |||
| cls_ids = top_indices[np.argsort(probs[top_indices])] | |||
| probs = probs[cls_ids].tolist() | |||
| cls_names = [self.model.id2label[cid] for cid in cls_ids] | |||
| return {OutputKeys.SCORES: probs, OutputKeys.LABELS: cls_names} | |||
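For completeness, a hedged sketch of constructing this class directly instead of going through the `pipeline()` factory; passing a plain model id string exercises the `Model.from_pretrained` branch and the default preprocessor construction in `__init__`.

from modelscope.pipelines.nlp import SentimentClassificationPipeline

pipe = SentimentClassificationPipeline(
    'damo/nlp_structbert_sentiment-classification_chinese-base')
print(pipe('服务很周到, 菜品也不错'))  # -> {'scores': [...], 'labels': [...]}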
| @@ -1,10 +1,12 @@ | |||
| from typing import Dict, Optional, Union | |||
| from typing import Any, Dict, Optional, Union | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import PalmForTextGeneration | |||
| from modelscope.preprocessors import TextGenerationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| import torch | |||
| from ...metainfo import Pipelines | |||
| from ...models import Model | |||
| from ...models.nlp import PalmForTextGeneration | |||
| from ...preprocessors import TextGenerationPreprocessor | |||
| from ...utils.constant import Tasks | |||
| from ..base import Pipeline, Tensor | |||
| from ..builder import PIPELINES | |||
| from ..outputs import OutputKeys | |||
| @@ -34,10 +36,17 @@ class TextGenerationPipeline(Pipeline): | |||
| model.tokenizer, | |||
| first_sequence='sentence', | |||
| second_sequence=None) | |||
| model.eval() | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| self.tokenizer = model.tokenizer | |||
| def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return super().forward(inputs, **forward_params) | |||
| def postprocess(self, inputs: Dict[str, Tensor], | |||
| **postprocess_params) -> Dict[str, str]: | |||
| """process the prediction results | |||
| Args: | |||
| @@ -1,10 +1,12 @@ | |||
| from typing import Any, Dict, Optional, Union | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import StructBertForTokenClassification | |||
| from modelscope.preprocessors import TokenClassifcationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| import torch | |||
| from ...metainfo import Pipelines | |||
| from ...models import Model | |||
| from ...models.nlp import SbertForTokenClassification | |||
| from ...preprocessors import TokenClassifcationPreprocessor | |||
| from ...utils.constant import Tasks | |||
| from ..base import Pipeline, Tensor | |||
| from ..builder import PIPELINES | |||
| from ..outputs import OutputKeys | |||
| @@ -17,7 +19,7 @@ __all__ = ['WordSegmentationPipeline'] | |||
| class WordSegmentationPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[StructBertForTokenClassification, str], | |||
| model: Union[SbertForTokenClassification, str], | |||
| preprocessor: Optional[TokenClassifcationPreprocessor] = None, | |||
| **kwargs): | |||
| """use `model` and `preprocessor` to create a nlp word segmentation pipeline for prediction | |||
| @@ -28,15 +30,23 @@ class WordSegmentationPipeline(Pipeline): | |||
| """ | |||
| model = model if isinstance( | |||
| model, | |||
| StructBertForTokenClassification) else Model.from_pretrained(model) | |||
| SbertForTokenClassification) else Model.from_pretrained(model) | |||
| if preprocessor is None: | |||
| preprocessor = TokenClassifcationPreprocessor(model.model_dir) | |||
| model.eval() | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| self.tokenizer = preprocessor.tokenizer | |||
| self.config = model.config | |||
| assert len(self.config.id2label) > 0 | |||
| self.id2label = self.config.id2label | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return super().forward(inputs, **forward_params) | |||
| def postprocess(self, inputs: Dict[str, Any], | |||
| **postprocess_params) -> Dict[str, str]: | |||
| """process the prediction results | |||
| Args: | |||
| @@ -5,7 +5,9 @@ from modelscope.utils.constant import Tasks | |||
| class OutputKeys(object): | |||
| SCORES = 'scores' | |||
| LABEL = 'label' | |||
| LABELS = 'labels' | |||
| LABEL_POS = 'label_pos' | |||
| POSES = 'poses' | |||
| CAPTION = 'caption' | |||
| BOXES = 'boxes' | |||
| @@ -16,6 +18,8 @@ class OutputKeys(object): | |||
| OUTPUT_PCM = 'output_pcm' | |||
| IMG_EMBEDDING = 'img_embedding' | |||
| TEXT_EMBEDDING = 'text_embedding' | |||
| RESPONSE = 'response' | |||
| PREDICTION = 'prediction' | |||
| TASK_OUTPUTS = { | |||
| @@ -119,6 +123,13 @@ TASK_OUTPUTS = { | |||
| # } | |||
| Tasks.sentence_similarity: [OutputKeys.SCORES, OutputKeys.LABELS], | |||
| # sentiment classification result for single sample | |||
| # { | |||
| # "labels": ["happy", "sad", "calm", "angry"], | |||
| # "scores": [0.9, 0.1, 0.05, 0.05] | |||
| # } | |||
| Tasks.sentiment_classification: [OutputKeys.SCORES, OutputKeys.LABELS], | |||
| # zero-shot classification result for single sample | |||
| # { | |||
| # "scores": [0.9, 0.1, 0.05, 0.05] | |||
| @@ -126,6 +137,39 @@ TASK_OUTPUTS = { | |||
| # } | |||
| Tasks.zero_shot_classification: [OutputKeys.SCORES, OutputKeys.LABELS], | |||
| # nli result for single sample | |||
| # { | |||
| # "labels": ["happy", "sad", "calm", "angry"], | |||
| # "scores": [0.9, 0.1, 0.05, 0.05] | |||
| # } | |||
| Tasks.nli: [OutputKeys.SCORES, OutputKeys.LABELS], | |||
| # {'pred': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05, | |||
| # 1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04, | |||
| # 6.45841064e-05, 1.15997791e-04, 5.11605394e-05, 9.87020373e-01, | |||
| # 2.66957268e-05, 4.72324500e-05, 9.74208378e-05, 4.18022355e-05, | |||
| # 2.97343540e-05, 5.81317654e-05, 5.44203431e-05, 6.28319322e-05, | |||
| # 7.34537680e-05, 6.61411541e-05, 3.62534920e-05, 8.58885178e-05, | |||
| # 8.24327726e-05, 4.66077945e-05, 5.32869453e-05, 4.16190960e-05, | |||
| # 5.97518992e-05, 3.92273068e-05, 3.44069012e-05, 9.92335918e-05, | |||
| # 9.25978165e-05, 6.26462061e-05, 3.32317031e-05, 1.32061413e-03, | |||
| # 2.01607945e-05, 3.36636294e-05, 3.99156743e-05, 5.84108493e-05, | |||
| # 2.53432900e-05, 4.95731190e-04, 2.64443643e-05, 4.46992999e-05, | |||
| # 2.42672231e-05, 4.75615161e-05, 2.66230145e-05, 4.00083954e-05, | |||
| # 2.90536875e-04, 4.23891543e-05, 8.63691166e-05, 4.98188965e-05, | |||
| # 3.47019341e-05, 4.52718523e-05, 4.20905781e-05, 5.50173208e-05, | |||
| # 4.92360487e-05, 3.56021264e-05, 2.13957210e-05, 6.17428886e-05, | |||
| # 1.43893281e-04, 7.32152112e-05, 2.91354867e-04, 2.46623786e-05, | |||
| # 3.61441926e-05, 3.38475402e-05, 3.44323053e-05, 5.70138109e-05, | |||
| # 4.31488479e-05, 4.94503947e-05, 4.30105974e-05, 1.00963116e-04, | |||
| # 2.82062047e-05, 1.15582036e-04, 4.48261271e-05, 3.99339879e-05, | |||
| # 7.27692823e-05], dtype=float32), 'label_pos': array([11]), 'label': 'lost_or_stolen_card'} | |||
| Tasks.dialog_intent_prediction: | |||
| [OutputKeys.PREDICTION, OutputKeys.LABEL_POS, OutputKeys.LABEL], | |||
| # sys : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!'] | |||
| Tasks.dialog_modeling: [OutputKeys.RESPONSE], | |||
| # ============ audio tasks =================== | |||
| # audio processed for single file in PCM format | |||
| @@ -11,6 +11,8 @@ try: | |||
| from .audio import LinearAECAndFbank | |||
| from .multi_modal import * # noqa F403 | |||
| from .nlp import * # noqa F403 | |||
| from .space.dialog_intent_prediction_preprocessor import * # noqa F403 | |||
| from .space.dialog_modeling_preprocessor import * # noqa F403 | |||
| except ModuleNotFoundError as e: | |||
| if str(e) == "No module named 'tensorflow'": | |||
| pass | |||
| @@ -5,15 +5,16 @@ from typing import Any, Dict, Union | |||
| from transformers import AutoTokenizer | |||
| from modelscope.metainfo import Preprocessors | |||
| from modelscope.utils.constant import Fields, InputFields | |||
| from modelscope.utils.type_assert import type_assert | |||
| from ..metainfo import Models, Preprocessors | |||
| from ..utils.constant import Fields, InputFields | |||
| from ..utils.type_assert import type_assert | |||
| from .base import Preprocessor | |||
| from .builder import PREPROCESSORS | |||
| __all__ = [ | |||
| 'Tokenize', 'SequenceClassificationPreprocessor', | |||
| 'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor', | |||
| 'NLIPreprocessor', 'SentimentClassificationPreprocessor', | |||
| 'FillMaskPreprocessor', 'ZeroShotClassificationPreprocessor' | |||
| ] | |||
| @@ -32,6 +33,140 @@ class Tokenize(Preprocessor): | |||
| return data | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.nli_tokenizer) | |||
| class NLIPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """preprocess the data via the vocab.txt from the `model_dir` path | |||
| Args: | |||
| model_dir (str): model path | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| from sofa import SbertTokenizer | |||
| self.model_dir: str = model_dir | |||
| self.first_sequence: str = kwargs.pop('first_sequence', | |||
| 'first_sequence') | |||
| self.second_sequence = kwargs.pop('second_sequence', 'second_sequence') | |||
| self.sequence_length = kwargs.pop('sequence_length', 128) | |||
| self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir) | |||
| @type_assert(object, tuple) | |||
| def __call__(self, data: tuple) -> Dict[str, Any]: | |||
| """process the raw input data | |||
| Args: | |||
| data (tuple): [sentence1, sentence2] | |||
| sentence1 (str): a sentence | |||
| Example: | |||
| 'you are so handsome.' | |||
| sentence2 (str): a sentence | |||
| Example: | |||
| 'you are so beautiful.' | |||
| Returns: | |||
| Dict[str, Any]: the preprocessed data | |||
| """ | |||
| sentence1, sentence2 = data | |||
| new_data = { | |||
| self.first_sequence: sentence1, | |||
| self.second_sequence: sentence2 | |||
| } | |||
| # preprocess the data for the model input | |||
| rst = { | |||
| 'id': [], | |||
| 'input_ids': [], | |||
| 'attention_mask': [], | |||
| 'token_type_ids': [] | |||
| } | |||
| max_seq_length = self.sequence_length | |||
| text_a = new_data[self.first_sequence] | |||
| text_b = new_data[self.second_sequence] | |||
| feature = self.tokenizer( | |||
| text_a, | |||
| text_b, | |||
| padding=False, | |||
| truncation=True, | |||
| max_length=max_seq_length) | |||
| rst['id'].append(new_data.get('id', str(uuid.uuid4()))) | |||
| rst['input_ids'].append(feature['input_ids']) | |||
| rst['attention_mask'].append(feature['attention_mask']) | |||
| rst['token_type_ids'].append(feature['token_type_ids']) | |||
| return rst | |||
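A quick sketch of what this preprocessor emits; the model_dir path is a placeholder and the ids below are illustrative, not real vocabulary ids.

preprocessor = NLIPreprocessor('/path/to/nli/model_dir')  # hypothetical local dir
features = preprocessor(('今天天气不错', '今天下暴雨'))
# roughly:
# {'id': ['<uuid4>'],
#  'input_ids': [[101, 2, 3, 102, 4, 5, 102]],
#  'attention_mask': [[1, 1, 1, 1, 1, 1, 1]],
#  'token_type_ids': [[0, 0, 0, 0, 1, 1, 1]]}
print(features.keys())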
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.sen_cls_tokenizer) | |||
| class SentimentClassificationPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """preprocess the data via the vocab.txt from the `model_dir` path | |||
| Args: | |||
| model_dir (str): model path | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| from sofa import SbertTokenizer | |||
| self.model_dir: str = model_dir | |||
| self.first_sequence: str = kwargs.pop('first_sequence', | |||
| 'first_sequence') | |||
| self.second_sequence = kwargs.pop('second_sequence', 'second_sequence') | |||
| self.sequence_length = kwargs.pop('sequence_length', 128) | |||
| self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir) | |||
| @type_assert(object, str) | |||
| def __call__(self, data: str) -> Dict[str, Any]: | |||
| """process the raw input data | |||
| Args: | |||
| data (str): a sentence | |||
| Example: | |||
| 'you are so handsome.' | |||
| Returns: | |||
| Dict[str, Any]: the preprocessed data | |||
| """ | |||
| new_data = {self.first_sequence: data} | |||
| # preprocess the data for the model input | |||
| rst = { | |||
| 'id': [], | |||
| 'input_ids': [], | |||
| 'attention_mask': [], | |||
| 'token_type_ids': [] | |||
| } | |||
| max_seq_length = self.sequence_length | |||
| text_a = new_data[self.first_sequence] | |||
| text_b = new_data.get(self.second_sequence, None) | |||
| feature = self.tokenizer( | |||
| text_a, | |||
| text_b, | |||
| padding='max_length', | |||
| truncation=True, | |||
| max_length=max_seq_length) | |||
| rst['id'].append(new_data.get('id', str(uuid.uuid4()))) | |||
| rst['input_ids'].append(feature['input_ids']) | |||
| rst['attention_mask'].append(feature['attention_mask']) | |||
| rst['token_type_ids'].append(feature['token_type_ids']) | |||
| return rst | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.bert_seq_cls_tokenizer) | |||
| class SequenceClassificationPreprocessor(Preprocessor): | |||
| @@ -178,7 +313,6 @@ class TextGenerationPreprocessor(Preprocessor): | |||
| rst['input_ids'].append(feature['input_ids']) | |||
| rst['attention_mask'].append(feature['attention_mask']) | |||
| return {k: torch.tensor(v) for k, v in rst.items()} | |||
| @@ -241,7 +375,7 @@ class FillMaskPreprocessor(Preprocessor): | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.sbert_token_cls_tokenizer) | |||
| Fields.nlp, module_name=Preprocessors.token_cls_tokenizer) | |||
| class TokenClassifcationPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| @@ -269,6 +403,7 @@ class TokenClassifcationPreprocessor(Preprocessor): | |||
| Returns: | |||
| Dict[str, Any]: the preprocessed data | |||
| """ | |||
| # preprocess the data for the model input | |||
| text = data.replace(' ', '').strip() | |||
| @@ -0,0 +1,57 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| from typing import Any, Dict | |||
| import json | |||
| from ...metainfo import Preprocessors | |||
| from ...utils.config import Config | |||
| from ...utils.constant import Fields, ModelFile | |||
| from ...utils.type_assert import type_assert | |||
| from ..base import Preprocessor | |||
| from ..builder import PREPROCESSORS | |||
| from .fields.intent_field import IntentBPETextField | |||
| __all__ = ['DialogIntentPredictionPreprocessor'] | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.dialog_intent_preprocessor) | |||
| class DialogIntentPredictionPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """preprocess the data via the vocab.txt from the `model_dir` path | |||
| Args: | |||
| model_dir (str): model path | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| self.model_dir: str = model_dir | |||
| self.config = Config.from_file( | |||
| os.path.join(self.model_dir, ModelFile.CONFIGURATION)) | |||
| self.text_field = IntentBPETextField( | |||
| self.model_dir, config=self.config) | |||
| self.categories = None | |||
| with open(os.path.join(self.model_dir, 'categories.json'), 'r') as f: | |||
| self.categories = json.load(f) | |||
| assert len(self.categories) == 77 | |||
| @type_assert(object, str) | |||
| def __call__(self, data: str) -> Dict[str, Any]: | |||
| """process the raw input data | |||
| Args: | |||
| data (str): a sentence | |||
| Example: | |||
| 'you are so handsome.' | |||
| Returns: | |||
| Dict[str, Any]: the preprocessed data | |||
| """ | |||
| samples = self.text_field.preprocessor([data]) | |||
| samples, _ = self.text_field.collate_fn_multi_turn(samples) | |||
| return samples | |||
| @@ -0,0 +1,51 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| from typing import Any, Dict | |||
| from ...metainfo import Preprocessors | |||
| from ...utils.config import Config | |||
| from ...utils.constant import Fields, ModelFile | |||
| from ...utils.type_assert import type_assert | |||
| from ..base import Preprocessor | |||
| from ..builder import PREPROCESSORS | |||
| from .fields.gen_field import MultiWOZBPETextField | |||
| __all__ = ['DialogModelingPreprocessor'] | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.dialog_modeling_preprocessor) | |||
| class DialogModelingPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """preprocess the data via the vocab.txt from the `model_dir` path | |||
| Args: | |||
| model_dir (str): model path | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| self.model_dir: str = model_dir | |||
| self.config = Config.from_file( | |||
| os.path.join(self.model_dir, ModelFile.CONFIGURATION)) | |||
| self.text_field = MultiWOZBPETextField( | |||
| self.model_dir, config=self.config) | |||
| @type_assert(object, Dict) | |||
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
| """process the raw input data | |||
| Args: | |||
| data (Dict[str, Any]): the dialog state dict; the latest user utterance is read from key 'user_input' | |||
| Example: | |||
| {'user_input': 'i want to find a cheap restaurant'} | |||
| Returns: | |||
| Dict[str, Any]: the preprocessed data | |||
| """ | |||
| user_ids = self.text_field.get_ids(data['user_input']) | |||
| data['user'] = user_ids | |||
| return data | |||
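A minimal call sketch; the model_dir is a placeholder, and any additional dialog-state keys the downstream model needs are left out here.

preprocessor = DialogModelingPreprocessor('/path/to/space/model_dir')  # hypothetical dir
turn = {'user_input': 'i am looking for a cheap hotel'}
turn = preprocessor(turn)
# turn['user'] now holds <sos_u> ... <eos_u> token ids produced by
# MultiWOZBPETextField.get_ids, ready to be consumed by SpaceForDialogModeling
print(turn['user'])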
| @@ -0,0 +1,675 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import random | |||
| from collections import OrderedDict | |||
| from itertools import chain | |||
| import numpy as np | |||
| from ....utils.nlp.space import ontology, utils | |||
| from ....utils.nlp.space.db_ops import MultiWozDB | |||
| from ....utils.nlp.space.utils import list2np | |||
| from ..tokenizer import Tokenizer | |||
| class BPETextField(object): | |||
| pad_token = '[PAD]' | |||
| bos_token = '[BOS]' | |||
| eos_token = '[EOS]' | |||
| unk_token = '[UNK]' | |||
| sos_u_token = '<sos_u>' | |||
| eos_u_token = '<eos_u>' | |||
| sos_b_token = '<sos_b>' | |||
| eos_b_token = '<eos_b>' | |||
| sos_d_token = '<sos_d>' | |||
| eos_d_token = '<eos_d>' | |||
| sos_a_token = '<sos_a>' | |||
| eos_a_token = '<eos_a>' | |||
| sos_db_token = '<sos_db>' | |||
| eos_db_token = '<eos_db>' | |||
| sos_r_token = '<sos_r>' | |||
| eos_r_token = '<eos_r>' | |||
| @property | |||
| def bot_id(self): | |||
| return 0 | |||
| @property | |||
| def user_id(self): | |||
| return 1 | |||
| @property | |||
| def vocab_size(self): | |||
| return self.tokenizer.vocab_size | |||
| @property | |||
| def num_specials(self): | |||
| return len(self.tokenizer.special_tokens) | |||
| @property | |||
| def pad_id(self): | |||
| return self.tokenizer.convert_tokens_to_ids([self.pad_token])[0] | |||
| @property | |||
| def bos_id(self): | |||
| return self.tokenizer.convert_tokens_to_ids([self.bos_token])[0] | |||
| @property | |||
| def eos_id(self): | |||
| return self.tokenizer.convert_tokens_to_ids([self.eos_token])[0] | |||
| @property | |||
| def unk_id(self): | |||
| return self.tokenizer.convert_tokens_to_ids([self.unk_token])[0] | |||
| @property | |||
| def sos_u_id(self): | |||
| return self.tokenizer.convert_tokens_to_ids([self.sos_u_token])[0] | |||
| @property | |||
| def eos_u_id(self): | |||
| return self.tokenizer.convert_tokens_to_ids([self.eos_u_token])[0] | |||
| @property | |||
| def sos_b_id(self): | |||
| return self.tokenizer.convert_tokens_to_ids([self.sos_b_token])[0] | |||
| @property | |||
| def eos_b_id(self): | |||
| return self.tokenizer.convert_tokens_to_ids([self.eos_b_token])[0] | |||
| @property | |||
| def sos_db_id(self): | |||
| return self.tokenizer.convert_tokens_to_ids([self.sos_db_token])[0] | |||
| @property | |||
| def eos_db_id(self): | |||
| return self.tokenizer.convert_tokens_to_ids([self.eos_db_token])[0] | |||
| @property | |||
| def sos_a_id(self): | |||
| return self.tokenizer.convert_tokens_to_ids([self.sos_a_token])[0] | |||
| @property | |||
| def eos_a_id(self): | |||
| return self.tokenizer.convert_tokens_to_ids([self.eos_a_token])[0] | |||
| @property | |||
| def sos_r_id(self): | |||
| return self.tokenizer.convert_tokens_to_ids([self.sos_r_token])[0] | |||
| @property | |||
| def eos_r_id(self): | |||
| return self.tokenizer.convert_tokens_to_ids([self.eos_r_token])[0] | |||
| @property | |||
| def sos_d_id(self): | |||
| return self.tokenizer.convert_tokens_to_ids([self.sos_d_token])[0] | |||
| @property | |||
| def eos_d_id(self): | |||
| return self.tokenizer.convert_tokens_to_ids([self.eos_d_token])[0] | |||
| def __init__(self, config): | |||
| self.gpu = 0 | |||
| self.tokenizer = None | |||
| self.vocab = None | |||
| self.db = None | |||
| self.set_stats = {} | |||
| self.prompt_num_for_understand = config.BPETextField.prompt_num_for_understand | |||
| self.prompt_num_for_policy = config.BPETextField.prompt_num_for_policy | |||
| self.understand_tokens = ontology.get_understand_tokens( | |||
| self.prompt_num_for_understand) | |||
| self.policy_tokens = ontology.get_policy_tokens( | |||
| self.prompt_num_for_policy) | |||
| self.with_query_bow = config.BPETextField.with_query_bow | |||
| self.understand = config.BPETextField.understand | |||
| self.policy = config.BPETextField.policy | |||
| self.batch_size = config.Trainer.batch_size | |||
| self.filtered = config.BPETextField.filtered | |||
| self.max_len = config.BPETextField.max_len | |||
| self.min_utt_len = config.BPETextField.min_utt_len | |||
| self.max_utt_len = config.BPETextField.max_utt_len | |||
| self.min_ctx_turn = config.BPETextField.min_ctx_turn | |||
| self.max_ctx_turn = config.BPETextField.max_ctx_turn - 1 # subtract reply turn | |||
| self.use_true_prev_bspn = config.Generator.use_true_prev_bspn | |||
| self.use_true_prev_aspn = config.Generator.use_true_prev_aspn | |||
| self.use_true_db_pointer = config.Generator.use_true_db_pointer | |||
| self.use_true_prev_resp = config.Generator.use_true_prev_resp | |||
| self.use_true_curr_bspn = config.Generator.use_true_curr_bspn | |||
| self.use_true_curr_aspn = config.Generator.use_true_curr_aspn | |||
| self.use_all_previous_context = config.Generator.use_all_previous_context | |||
| self.use_true_bspn_for_ctr_eval = config.Generator.use_true_bspn_for_ctr_eval | |||
| self.use_true_domain_for_ctr_eval = config.Generator.use_true_domain_for_ctr_eval | |||
| def collate_fn_multi_turn(self, samples): | |||
| batch_size = len(samples) | |||
| batch = {} | |||
| src = [sp['src'][-self.max_ctx_turn:] for sp in samples] | |||
| query_token, src_token, src_pos, src_turn, src_role = [], [], [], [], [] | |||
| for utts in src: | |||
| query_token.append(utts[-1]) | |||
| utt_lens = [len(utt) for utt in utts] | |||
| # Token ids | |||
| src_token.append(list(chain(*utts))[-self.max_len:]) | |||
| # Position ids | |||
| pos = [list(range(utt_len)) for utt_len in utt_lens] | |||
| src_pos.append(list(chain(*pos))[-self.max_len:]) | |||
| # Turn ids | |||
| turn = [[len(utts) - i] * l for i, l in enumerate(utt_lens)] | |||
| src_turn.append(list(chain(*turn))[-self.max_len:]) | |||
| # Role ids | |||
| role = [ | |||
| [self.bot_id if (len(utts) - i) % 2 == 0 else self.user_id] * l | |||
| for i, l in enumerate(utt_lens) | |||
| ] | |||
| src_role.append(list(chain(*role))[-self.max_len:]) | |||
| # src sequence and tgt sequence should be padded separately, to make sure the first word is aligned | |||
| src_token = list2np(src_token, padding=self.pad_id) | |||
| src_pos = list2np(src_pos, padding=self.pad_id) | |||
| src_turn = list2np(src_turn, padding=self.pad_id) | |||
| src_role = list2np(src_role, padding=self.pad_id) | |||
| batch['src_token'] = src_token | |||
| batch['src_pos'] = src_pos | |||
| batch['src_type'] = src_role | |||
| batch['src_turn'] = src_turn | |||
| batch['src_mask'] = (src_token != self.pad_id).astype('int64') | |||
| if self.with_query_bow: | |||
| query_token = list2np(query_token, padding=self.pad_id) | |||
| batch['query_token'] = query_token | |||
| batch['query_mask'] = (query_token != self.pad_id).astype('int64') | |||
| if self.understand_ids and self.understand: | |||
| understand = [self.understand_ids for _ in samples] | |||
| understand_token = np.array(understand).astype('int64') | |||
| batch['understand_token'] = understand_token | |||
| batch['understand_mask'] = \ | |||
| (understand_token != self.pad_id).astype('int64') | |||
| if self.policy_ids and self.policy: | |||
| policy = [self.policy_ids for _ in samples] | |||
| policy_token = np.array(policy).astype('int64') | |||
| batch['policy_token'] = policy_token | |||
| batch['policy_mask'] = \ | |||
| (policy_token != self.pad_id).astype('int64') | |||
| if 'tgt' in samples[0]: | |||
| tgt = [sp['tgt'] for sp in samples] | |||
| # Token ids & Label ids | |||
| tgt_token = list2np(tgt, padding=self.pad_id) | |||
| # Position ids | |||
| tgt_pos = np.zeros_like(tgt_token) | |||
| tgt_pos[:] = np.arange(tgt_token.shape[1], dtype=tgt_token.dtype) | |||
| # Turn ids | |||
| tgt_turn = np.zeros_like(tgt_token) | |||
| # Role ids | |||
| tgt_role = np.full_like(tgt_token, self.bot_id) | |||
| batch['tgt_token'] = tgt_token | |||
| batch['tgt_pos'] = tgt_pos | |||
| batch['tgt_type'] = tgt_role | |||
| batch['tgt_turn'] = tgt_turn | |||
| batch['tgt_mask'] = (tgt_token != self.pad_id).astype('int64') | |||
| return batch, batch_size | |||
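To help reviewers follow the padding logic above, a compact summary of what `collate_fn_multi_turn` returns; the field list is read directly from the code and the shapes are the padded ones.

# batch, batch_size = text_field.collate_fn_multi_turn(samples)
# batch is a dict of numpy int64 arrays, padded with pad_id:
#   src_token / src_pos / src_type / src_turn : [batch_size, src_len]
#   src_mask                                  : [batch_size, src_len], 1 for real tokens
#   query_token / query_mask                  : only when with_query_bow is set
#   understand_token / understand_mask        : only when understand prompts are enabled
#   policy_token / policy_mask                : only when policy prompts are enabled
#   tgt_token / tgt_pos / tgt_type / tgt_turn / tgt_mask : only when samples carry 'tgt'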
| def _bucket_by_turn(self, encoded_data): | |||
| turn_bucket = {} | |||
| for dial in encoded_data: | |||
| turn_len = len(dial) | |||
| if turn_len not in turn_bucket: | |||
| turn_bucket[turn_len] = [] | |||
| turn_bucket[turn_len].append(dial) | |||
| return OrderedDict(sorted(turn_bucket.items(), key=lambda i: i[0])) | |||
| def _construct_mini_batch(self, data): | |||
| all_batches = [] | |||
| batch = [] | |||
| for dial in data: | |||
| batch.append(dial) | |||
| if len(batch) == self.batch_size: | |||
| # print('batch size: %d, batch num +1'%(len(batch))) | |||
| all_batches.append(batch) | |||
| batch = [] | |||
| # if the remainder is larger than half a batch, keep it as its own batch; otherwise merge it into the previous batch | |||
| # print('last batch size: %d, batch num +1'%(len(batch))) | |||
| # if (len(batch) % len(cfg.cuda_device)) != 0: | |||
| # batch = batch[:-(len(batch) % len(cfg.cuda_device))] | |||
| # TODO deal with deleted data | |||
| if self.gpu <= 1: | |||
| if len(batch) > 0.5 * self.batch_size: | |||
| all_batches.append(batch) | |||
| elif len(all_batches): | |||
| all_batches[-1].extend(batch) | |||
| else: | |||
| all_batches.append(batch) | |||
| return all_batches | |||
| def transpose_batch(self, batch): | |||
| dial_batch = [] | |||
| turn_num = len(batch[0]) | |||
| for turn in range(turn_num): | |||
| turn_l = {} | |||
| for dial in batch: | |||
| this_turn = dial[turn] | |||
| for k in this_turn: | |||
| if k not in turn_l: | |||
| turn_l[k] = [] | |||
| turn_l[k].append(this_turn[k]) | |||
| dial_batch.append(turn_l) | |||
| return dial_batch | |||
| def get_eval_data(self, set_name='dev'): | |||
| name_to_set = {'train': self.train, 'test': self.test, 'dev': self.dev} | |||
| dial = name_to_set[set_name] | |||
| if set_name not in self.set_stats: | |||
| self.set_stats[set_name] = {} | |||
| num_turns = 0 | |||
| num_dials = len(dial) | |||
| for d in dial: | |||
| num_turns += len(d) | |||
| self.set_stats[set_name]['num_turns'] = num_turns | |||
| self.set_stats[set_name]['num_dials'] = num_dials | |||
| return dial | |||
| def get_nontranspose_data_iterator(self, all_batches): | |||
| for i, batch in enumerate(all_batches): | |||
| yield batch | |||
| def get_data_iterator(self, all_batches): | |||
| for i, batch in enumerate(all_batches): | |||
| yield self.transpose_batch(batch) | |||
| class MultiWOZBPETextField(BPETextField): | |||
| def __init__(self, model_dir, config): | |||
| super(MultiWOZBPETextField, self).__init__(config) | |||
| import spacy | |||
| self.nlp = spacy.load('en_core_web_sm') | |||
| self.db = MultiWozDB( | |||
| model_dir, { | |||
| 'attraction': 'db/attraction_db_processed.json', | |||
| 'hospital': 'db/hospital_db_processed.json', | |||
| 'hotel': 'db/hotel_db_processed.json', | |||
| 'police': 'db/police_db_processed.json', | |||
| 'restaurant': 'db/restaurant_db_processed.json', | |||
| 'taxi': 'db/taxi_db_processed.json', | |||
| 'train': 'db/train_db_processed.json', | |||
| }) | |||
| self._build_vocab(model_dir) | |||
| special_tokens = [ | |||
| self.pad_token, self.bos_token, self.eos_token, self.unk_token | |||
| ] | |||
| special_tokens.extend(self.add_sepcial_tokens()) | |||
| self.tokenizer = Tokenizer( | |||
| vocab_path=os.path.join(model_dir, 'vocab.txt'), | |||
| special_tokens=special_tokens, | |||
| tokenizer_type=config.BPETextField.tokenizer_type) | |||
| self.understand_ids = self.tokenizer.convert_tokens_to_ids( | |||
| self.understand_tokens) | |||
| self.policy_ids = self.tokenizer.convert_tokens_to_ids( | |||
| self.policy_tokens) | |||
| return | |||
| def get_ids(self, data: str): | |||
| result = [self.sos_u_id] + self.tokenizer.convert_tokens_to_ids( | |||
| self.tokenizer.tokenize( | |||
| self._get_convert_str(data))) + [self.eos_u_id] | |||
| return result | |||
| def inverse_transpose_turn(self, turn_list): | |||
| """ | |||
| eval, one dialog at a time | |||
| """ | |||
| dialogs = {} | |||
| turn_num = len(turn_list) | |||
| dial_id = turn_list[0]['dial_id'] | |||
| dialogs[dial_id] = [] | |||
| for turn_idx in range(turn_num): | |||
| dial_turn = {} | |||
| turn = turn_list[turn_idx] | |||
| for key, value in turn.items(): | |||
| if key == 'dial_id': | |||
| continue | |||
| if key == 'pointer' and self.db is not None: | |||
| turn_domain = turn['turn_domain'][-1] | |||
| value = self.db.pointerBack(value, turn_domain) | |||
| dial_turn[key] = value | |||
| dialogs[dial_id].append(dial_turn) | |||
| return dialogs | |||
| def inverse_transpose_batch(self, turn_batch_list): | |||
| """ | |||
| :param turn_batch_list: list of transpose dial batch | |||
| """ | |||
| dialogs = {} | |||
| total_turn_num = len(turn_batch_list) | |||
| # initialize | |||
| for idx_in_batch, dial_id in enumerate(turn_batch_list[0]['dial_id']): | |||
| dialogs[dial_id] = [] | |||
| for turn_n in range(total_turn_num): | |||
| dial_turn = {} | |||
| turn_batch = turn_batch_list[turn_n] | |||
| for key, v_list in turn_batch.items(): | |||
| if key == 'dial_id': | |||
| continue | |||
| value = v_list[idx_in_batch] | |||
| if key == 'pointer' and self.db is not None: | |||
| turn_domain = turn_batch['turn_domain'][idx_in_batch][ | |||
| -1] | |||
| value = self.db.pointerBack(value, turn_domain) | |||
| dial_turn[key] = value | |||
| dialogs[dial_id].append(dial_turn) | |||
| return dialogs | |||
| def get_batches(self, set_name): | |||
| """ | |||
| compute dataset stats. | |||
| """ | |||
| global dia_count | |||
| log_str = '' | |||
| name_to_set = {'train': self.train, 'test': self.test, 'dev': self.dev} | |||
| dial = name_to_set[set_name] | |||
| turn_bucket = self._bucket_by_turn(dial) | |||
| # self._shuffle_turn_bucket(turn_bucket) | |||
| all_batches = [] | |||
| if set_name not in self.set_stats: | |||
| self.set_stats[set_name] = {} | |||
| num_training_steps = 0 | |||
| num_turns = 0 | |||
| num_dials = 0 | |||
| for k in turn_bucket: | |||
| if set_name != 'test' and k == 1 or k >= 17: | |||
| continue | |||
| batches = self._construct_mini_batch(turn_bucket[k]) | |||
| try: | |||
| log_str += 'turn num:%d, dial num: %d, batch num: %d last batch len: %d\n' % ( | |||
| k, len(turn_bucket[k]), len(batches), len(batches[-1])) | |||
| except Exception: | |||
| log_str += 'turn num:%d, dial num: %d, batch num: %d last batch len: %d\n' % ( | |||
| k, len(turn_bucket[k]), len(batches), 0.0) | |||
| # print("turn num:%d, dial num:v%d, batch num: %d, "%(k, len(turn_bucket[k]), len(batches))) | |||
| num_training_steps += k * len(batches) | |||
| num_turns += k * len(turn_bucket[k]) | |||
| num_dials += len(turn_bucket[k]) | |||
| all_batches += batches | |||
| log_str += 'total batch num: %d\n' % len(all_batches) | |||
| # print('total batch num: %d'%len(all_batches)) | |||
| # print('dialog count: %d'%dia_count) | |||
| # return all_batches | |||
| # log stats | |||
| # logging.info(log_str) | |||
| # cfg.num_training_steps = num_training_steps * cfg.epoch_num | |||
| self.set_stats[set_name][ | |||
| 'num_training_steps_per_epoch'] = num_training_steps # turn-level steps | |||
| self.set_stats[set_name]['num_turns'] = num_turns | |||
| self.set_stats[set_name]['num_dials'] = num_dials | |||
| if set_name == 'train': | |||
| random.shuffle(all_batches) | |||
| return all_batches | |||
| def add_sepcial_tokens(self): | |||
| """ | |||
| add special tokens to gpt tokenizer | |||
| serves a similar role to Vocab.construct() | |||
| make a dict of special tokens | |||
| """ | |||
| special_tokens = [] | |||
| prompt_tokens = self.understand_tokens + self.policy_tokens | |||
| special_tokens.extend( | |||
| ontology.get_special_tokens(other_tokens=prompt_tokens)) | |||
| for word in ontology.all_domains + ['general']: | |||
| word = '[' + word + ']' | |||
| special_tokens.append(word) | |||
| for word in ontology.all_acts: | |||
| word = '[' + word + ']' | |||
| special_tokens.append(word) | |||
| for word in self.vocab._word2idx.keys(): | |||
| if word.startswith('[value_') and word.endswith(']'): | |||
| special_tokens.append(word) | |||
| return special_tokens | |||
| def _build_vocab(self, model_dir: str): | |||
| self.vocab = utils.MultiWOZVocab(3000) | |||
| vp = os.path.join(model_dir, 'vocab') | |||
| self.vocab.load_vocab(vp) | |||
| return self.vocab.vocab_size | |||
| def _get_convert_str(self, sent): | |||
| assert isinstance(sent, str) | |||
| return ' '.join([ | |||
| self.tokenizer.spec_convert_dict.get(tok, tok) | |||
| for tok in sent.split() | |||
| ]) | |||
| def bspan_to_DBpointer(self, bspan, turn_domain): | |||
| constraint_dict = self.bspan_to_constraint_dict(bspan) | |||
| # print(constraint_dict) | |||
| matnums = self.db.get_match_num(constraint_dict) | |||
| match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1] | |||
| match_dom = match_dom[1:-1] if match_dom.startswith('[') else match_dom | |||
| match = matnums[match_dom] | |||
| # vector = self.db.addDBPointer(match_dom, match) | |||
| vector = self.db.addDBIndicator(match_dom, match) | |||
| return vector | |||
| def bspan_to_constraint_dict(self, bspan, bspn_mode='bspn'): | |||
| """ | |||
| ['[hotel]', 'pricerange', 'cheap', 'type', 'hotel'] -> {'hotel': {'pricerange': 'cheap', 'type': 'hotel'}} | |||
| """ | |||
| bspan = bspan.split() if isinstance(bspan, str) else bspan | |||
| constraint_dict = {} | |||
| domain = None | |||
| conslen = len(bspan) | |||
| for idx, cons in enumerate(bspan): | |||
| cons = self.vocab.decode(cons) if type(cons) is not str else cons | |||
| if cons == '<eos_b>': | |||
| break | |||
| if '[' in cons: | |||
| if cons[1:-1] not in ontology.all_domains: | |||
| continue | |||
| domain = cons[1:-1] | |||
| elif cons in ontology.get_slot: | |||
| if domain is None: | |||
| continue | |||
| if cons == 'people': | |||
| # distinguish the slot 'people' from values such as "people's portraits..." | |||
| try: | |||
| ns = bspan[idx + 1] | |||
| ns = self.vocab.decode(ns) if type( | |||
| ns) is not str else ns | |||
| if ns == "'s": | |||
| continue | |||
| except Exception: | |||
| continue | |||
| if not constraint_dict.get(domain): | |||
| constraint_dict[domain] = {} | |||
| if bspn_mode == 'bsdx': | |||
| constraint_dict[domain][cons] = 1 | |||
| continue | |||
| vidx = idx + 1 | |||
| if vidx == conslen: | |||
| break | |||
| vt_collect = [] | |||
| vt = bspan[vidx] | |||
| vt = self.vocab.decode(vt) if type(vt) is not str else vt | |||
| while vidx < conslen and vt != '<eos_b>' and '[' not in vt and vt not in ontology.get_slot: | |||
| vt_collect.append(vt) | |||
| vidx += 1 | |||
| if vidx == conslen: | |||
| break | |||
| vt = bspan[vidx] | |||
| vt = self.vocab.decode(vt) if type(vt) is not str else vt | |||
| if vt_collect: | |||
| constraint_dict[domain][cons] = ' '.join(vt_collect) | |||
| return constraint_dict | |||
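Building on the docstring example, multi-token values are space-joined and parsing stops at '<eos_b>'. A hedged illustration, assuming 'food' and 'area' are listed in ontology.get_slot:

    bspan = ['[restaurant]', 'food', 'modern', 'european', 'area', 'centre', '<eos_b>']
    # bspan_to_constraint_dict(bspan) is expected to return
    # {'restaurant': {'food': 'modern european', 'area': 'centre'}}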
| def convert_batch_turn(self, turn_batch, pv_batch, first_turn=False): | |||
| """ | |||
| convert the current and the last turn | |||
| concat [U_0,R_0,...,U_{t-1}, R_{t-1}, U_t, B_t, A_t, R_t] | |||
| first turn: [U_t, B_t, A_t, R_t] | |||
| try: [user, bspn, db, aspn, resp] | |||
| """ | |||
| inputs = [] | |||
| if first_turn: | |||
| batch_zipped = zip(turn_batch['user'], turn_batch['bspn'], | |||
| turn_batch['db'], turn_batch['aspn'], | |||
| turn_batch['resp']) | |||
| for u, b, db, a, r in batch_zipped: | |||
| if self.use_true_curr_bspn: | |||
| src = [u + b + db] | |||
| tgt = a + r | |||
| else: | |||
| src = [u] | |||
| tgt = b + db + a + r | |||
| inputs.append({'src': src, 'tgt': tgt}) | |||
| pv = [src[-1], tgt] | |||
| pv_batch.append(pv) | |||
| else: | |||
| batch_zipped = zip(pv_batch, turn_batch['user'], | |||
| turn_batch['bspn'], turn_batch['db'], | |||
| turn_batch['aspn'], turn_batch['resp']) | |||
| for i, (pv, u, b, db, a, r) in enumerate(batch_zipped): | |||
| if self.use_true_curr_bspn: | |||
| src = pv + [u + b + db] | |||
| tgt = a + r | |||
| else: | |||
| src = pv + [u] | |||
| tgt = b + db + a + r | |||
| inputs.append({'src': src, 'tgt': tgt}) | |||
| pv = [src[-1], tgt] | |||
| pv_batch[i].extend(pv) | |||
| return inputs, pv_batch | |||
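The resulting training samples follow the UBAR-style layout (ids elided, shown here with use_true_curr_bspn=False):

    # first turn:   src = [U_0]                                 tgt = B_0 + DB_0 + A_0 + R_0
    # later turn t: src = [U_0, B_0+DB_0+A_0+R_0, ..., U_t]     tgt = B_t + DB_t + A_t + R_t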
| def wrap_result_lm(self, result_dict, eos_syntax=None): | |||
| results = [] | |||
| eos_syntax = ontology.eos_tokens if not eos_syntax else eos_syntax | |||
| sos_syntax = ontology.sos_tokens | |||
| # ground truth bs, as, ds.. generate response | |||
| field = [ | |||
| 'dial_id', 'turn_num', 'user', 'bspn_gen', 'bsdx', 'resp_gen', | |||
| 'resp', 'aspn_gen', 'aspn', 'dspn_gen', 'dspn', 'bspn', 'pointer', | |||
| 'qspn_gen', 'qspn' | |||
| ] | |||
| for dial_id, turns in result_dict.items(): | |||
| entry = {'dial_id': dial_id, 'trun_num': len(turns)} | |||
| for f in field[2:]: | |||
| entry[f] = '' # TODO ??? | |||
| results.append(entry) | |||
| for turn_idx, turn in enumerate(turns): | |||
| entry = {'dial_id': dial_id} | |||
| for key in field: | |||
| if key in ['dial_id']: | |||
| continue | |||
| v = turn.get(key, '') | |||
| if key == 'turn_domain': | |||
| v = ' '.join(v) | |||
| if key in eos_syntax and v != '': | |||
| # remove eos tokens | |||
| v = self.tokenizer.decode(v) | |||
| v = v.split() | |||
| # remove eos/sos in span | |||
| if eos_syntax[key] in v: | |||
| v.remove(eos_syntax[key]) | |||
| if sos_syntax[key] in v: | |||
| v.remove(sos_syntax[key]) | |||
| v = ' '.join(v) | |||
| else: | |||
| pass # v = v | |||
| entry[key] = v | |||
| results.append(entry) | |||
| return results, field | |||
| def convert_turn_eval(self, turn, pv_turn, first_turn=False): | |||
| """ | |||
| input: [all previous ubar, U_t, B_t, A_t] predict R_t | |||
| first turn: [U_t, B_t, A_t] predict R_t | |||
| regarding the context, using the full previous ubar history is too slow, so by default only the previous turn is used (see use_all_previous_context) | |||
| """ | |||
| inputs = {} | |||
| context_list = [] | |||
| prompt_id = None | |||
| if self.use_true_curr_bspn: | |||
| if self.use_true_curr_aspn: # only predict resp | |||
| context_list = ['user', 'bspn', 'db', 'aspn'] | |||
| prompt_id = self.sos_r_id | |||
| else: # predicted aspn | |||
| context_list = ['user', 'bspn', 'db'] | |||
| prompt_id = self.sos_a_id | |||
| else: # predict bspn aspn resp. db are not predicted. this part tbd. | |||
| context_list = ['user'] | |||
| prompt_id = self.sos_b_id | |||
| if first_turn: | |||
| context = [] | |||
| for c in context_list: | |||
| context += turn[c] | |||
| inputs['src'] = [context] | |||
| inputs['labels'] = [context] | |||
| else: | |||
| context = [] | |||
| for c in context_list: | |||
| context += turn[c] | |||
| if self.use_true_curr_bspn: | |||
| pv_context = pv_turn['labels'] + [ | |||
| pv_turn['aspn'] + pv_turn['resp'] | |||
| ] | |||
| else: | |||
| pv_info = pv_turn['bspn'] + pv_turn['db'] + pv_turn[ | |||
| 'aspn'] + pv_turn['resp'] | |||
| pv_context = pv_turn['labels'] + [pv_info] | |||
| # prompt response, add sos_r | |||
| inputs['src'] = pv_context + [context] | |||
| if self.use_all_previous_context: | |||
| inputs['labels'] = pv_context + [ | |||
| context | |||
| ] # use all previous ubar history | |||
| else: | |||
| inputs['labels'] = [context] # use previous turn | |||
| return inputs, prompt_id | |||
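At inference time the layout is analogous, with prompt_id selecting what gets decoded first (sos_b/sos_a/sos_r). With use_true_curr_bspn=False and a non-first turn:

    # src    = labels from turn t-1 + [B_{t-1}+DB_{t-1}+A_{t-1}+R_{t-1}] + [U_t]
    # prompt = sos_b_id, i.e. the model first generates the belief span B_t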
| @@ -0,0 +1,668 @@ | |||
| from __future__ import (absolute_import, division, print_function, | |||
| unicode_literals) | |||
| import collections | |||
| import logging | |||
| import os | |||
| import sys | |||
| import unicodedata | |||
| import json | |||
| import regex as re | |||
| def clean_string(string): | |||
| replace_mp = { | |||
| ' - ': '-', | |||
| " ' ": "'", | |||
| " n't": "n't", | |||
| " 'm": "'m", | |||
| ' do not': " don't", | |||
| " 's": "'s", | |||
| " 've": "'ve", | |||
| " 're": "'re" | |||
| } | |||
| for k, v in replace_mp.items(): | |||
| string = string.replace(k, v) | |||
| return string | |||
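clean_string only undoes a fixed set of detokenization artifacts, e.g.:

    clean_string("i 'm sure it do not matter")   # -> "i'm sure it don't matter"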
| class Tokenizer(object): | |||
| def __init__(self, vocab_path, special_tokens=[], tokenizer_type='Bert'): | |||
| self.tokenizer_type = tokenizer_type | |||
| if tokenizer_type == 'Bert': | |||
| self.spec_convert_dict = { | |||
| '[BOS]': '[unused0]', | |||
| '[EOS]': '[unused1]' | |||
| } | |||
| for token in special_tokens: | |||
| if token not in self.spec_convert_dict and token not in [ | |||
| '[PAD]', '[UNK]' | |||
| ]: | |||
| self.spec_convert_dict[ | |||
| token] = f'[unused{len(self.spec_convert_dict)}]' | |||
| self.spec_revert_dict = { | |||
| v: k | |||
| for k, v in self.spec_convert_dict.items() | |||
| } | |||
| special_tokens = [ | |||
| self.spec_convert_dict.get(tok, tok) for tok in special_tokens | |||
| ] | |||
| self.special_tokens = ('[UNK]', '[SEP]', '[PAD]', '[CLS]', | |||
| '[MASK]') | |||
| self.special_tokens += tuple(x for x in special_tokens | |||
| if x not in self.special_tokens) | |||
| self._tokenizer = BertTokenizer( | |||
| vocab_path, never_split=self.special_tokens) | |||
| for tok in self.special_tokens: | |||
| assert tok in self._tokenizer.vocab, f"special token '{tok}' is not in the vocabulary" | |||
| self.vocab_size = len(self._tokenizer.vocab) | |||
| elif tokenizer_type == 'GPT2': | |||
| self.spec_convert_dict = {'[UNK]': '<unk>'} | |||
| self.spec_revert_dict = { | |||
| v: k | |||
| for k, v in self.spec_convert_dict.items() | |||
| } | |||
| special_tokens = [ | |||
| tok for tok in special_tokens | |||
| if tok not in self.spec_convert_dict | |||
| ] | |||
| vocab_file = os.path.join(vocab_path, 'vocab.json') | |||
| merges_file = os.path.join(vocab_path, 'merges.txt') | |||
| self._tokenizer = GPT2Tokenizer( | |||
| vocab_file, merges_file, special_tokens=special_tokens) | |||
| self.num_specials = len(special_tokens) | |||
| self.vocab_size = len(self._tokenizer) | |||
| else: | |||
| raise ValueError | |||
| def tokenize(self, text): | |||
| return self._tokenizer.tokenize(text) | |||
| def convert_tokens_to_ids(self, tokens): | |||
| if self.tokenizer_type == 'Bert': | |||
| tokens = [self.spec_convert_dict.get(tok, tok) for tok in tokens] | |||
| ids = self._tokenizer.convert_tokens_to_ids(tokens) | |||
| return ids | |||
| else: | |||
| tokens = [self.spec_convert_dict.get(tok, tok) for tok in tokens] | |||
| ids = self._tokenizer.convert_tokens_to_ids(tokens) | |||
| ids = [(i + self.num_specials) % self.vocab_size for i in ids] | |||
| return ids | |||
| def convert_ids_to_tokens(self, ids): | |||
| if self.tokenizer_type == 'Bert': | |||
| tokens = self._tokenizer.convert_ids_to_tokens(ids) | |||
| tokens = [self.spec_revert_dict.get(tok, tok) for tok in tokens] | |||
| return tokens | |||
| else: | |||
| ids = [(i - self.num_specials) % self.vocab_size for i in ids] | |||
| tokens = self._tokenizer.convert_ids_to_tokens(ids) | |||
| tokens = [self.spec_revert_dict.get(tok, tok) for tok in tokens] | |||
| return tokens | |||
| def decode(self, ids, ignore_tokens=[]): | |||
| tokens = self.convert_ids_to_tokens(ids) | |||
| if len(ignore_tokens) > 0: | |||
| ignore_tokens = set(ignore_tokens) | |||
| tokens = [tok for tok in tokens if tok not in ignore_tokens] | |||
| if self.tokenizer_type == 'Bert': | |||
| string = ' '.join(tokens).replace(' ##', '') | |||
| else: | |||
| string = ''.join(tokens) | |||
| string = bytearray([ | |||
| self._tokenizer.byte_decoder[c] for c in string | |||
| ]).decode('utf-8') | |||
| string = clean_string(string) | |||
| return string | |||
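A hedged usage sketch of the wrapper above (assuming a BERT-style vocab.txt that contains the [unusedN] slots the special tokens are remapped to; the path is hypothetical):

    tok = Tokenizer(vocab_path='model_dir/vocab.txt',
                    special_tokens=['[BOS]', '[EOS]'],
                    tokenizer_type='Bert')
    ids = tok.convert_tokens_to_ids(
        ['[BOS]'] + tok.tokenize('book a cheap hotel') + ['[EOS]'])
    tok.decode(ids, ignore_tokens=['[BOS]', '[EOS]'])   # -> 'book a cheap hotel'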
| # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Tokenization classes.""" | |||
| logger = logging.getLogger(__name__) | |||
| def load_vocab(vocab_file): | |||
| """Loads a vocabulary file into a dictionary.""" | |||
| vocab = collections.OrderedDict() | |||
| index = 0 | |||
| with open(vocab_file, 'r', encoding='utf-8') as reader: | |||
| while True: | |||
| token = reader.readline() | |||
| if not token: | |||
| break | |||
| token = token.strip() | |||
| vocab[token] = index | |||
| index += 1 | |||
| return vocab | |||
| def whitespace_tokenize(text): | |||
| """Runs basic whitespace cleaning and splitting on a piece of text.""" | |||
| text = text.strip() | |||
| if not text: | |||
| return [] | |||
| tokens = text.split() | |||
| return tokens | |||
| class BertTokenizer(object): | |||
| """Runs end-to-end tokenization: punctuation splitting + wordpiece""" | |||
| def __init__(self, | |||
| vocab_file, | |||
| do_lower_case=True, | |||
| max_len=None, | |||
| do_basic_tokenize=True, | |||
| never_split=('[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]')): | |||
| """Constructs a BertTokenizer. | |||
| Args: | |||
| vocab_file: Path to a one-wordpiece-per-line vocabulary file | |||
| do_lower_case: Whether to lower case the input | |||
| Only has an effect when do_wordpiece_only=False | |||
| do_basic_tokenize: Whether to do basic tokenization before wordpiece. | |||
| max_len: An artificial maximum length to truncate tokenized sequences to; | |||
| Effective maximum length is always the minimum of this | |||
| value (if specified) and the underlying BERT model's | |||
| sequence length. | |||
| never_split: List of tokens which will never be split during tokenization. | |||
| Only has an effect when do_wordpiece_only=False | |||
| """ | |||
| if not os.path.isfile(vocab_file): | |||
| raise ValueError( | |||
| "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " | |||
| 'model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`' | |||
| .format(vocab_file)) | |||
| self.vocab = load_vocab(vocab_file) | |||
| self.ids_to_tokens = collections.OrderedDict([ | |||
| (ids, tok) for tok, ids in self.vocab.items() | |||
| ]) | |||
| self.do_basic_tokenize = do_basic_tokenize | |||
| if do_basic_tokenize: | |||
| self.basic_tokenizer = BasicTokenizer( | |||
| do_lower_case=do_lower_case, never_split=never_split) | |||
| self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) | |||
| self.max_len = max_len if max_len is not None else int(1e12) | |||
| def tokenize(self, text): | |||
| split_tokens = [] | |||
| if self.do_basic_tokenize: | |||
| for token in self.basic_tokenizer.tokenize(text): | |||
| for sub_token in self.wordpiece_tokenizer.tokenize(token): | |||
| split_tokens.append(sub_token) | |||
| else: | |||
| split_tokens = self.wordpiece_tokenizer.tokenize(text) | |||
| return split_tokens | |||
| def convert_tokens_to_ids(self, tokens): | |||
| """Converts a sequence of tokens into ids using the vocab.""" | |||
| ids = [] | |||
| for token in tokens: | |||
| ids.append(self.vocab[token]) | |||
| if len(ids) > self.max_len: | |||
| logger.warning( | |||
| 'Token indices sequence length is longer than the specified maximum ' | |||
| ' sequence length for this BERT model ({} > {}). Running this' | |||
| ' sequence through BERT will result in indexing errors'.format( | |||
| len(ids), self.max_len)) | |||
| return ids | |||
| def convert_ids_to_tokens(self, ids): | |||
| """Converts a sequence of ids in wordpiece tokens using the vocab.""" | |||
| tokens = [] | |||
| for i in ids: | |||
| tokens.append(self.ids_to_tokens[i]) | |||
| return tokens | |||
| class BasicTokenizer(object): | |||
| """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" | |||
| def __init__(self, | |||
| do_lower_case=True, | |||
| never_split=('[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]')): | |||
| """Constructs a BasicTokenizer. | |||
| Args: | |||
| do_lower_case: Whether to lower case the input. | |||
| """ | |||
| self.do_lower_case = do_lower_case | |||
| self.never_split = never_split | |||
| def tokenize(self, text): | |||
| """Tokenizes a piece of text.""" | |||
| text = self._clean_text(text) | |||
| # This was added on November 1st, 2018 for the multilingual and Chinese | |||
| # models. This is also applied to the English models now, but it doesn't | |||
| # matter since the English models were not trained on any Chinese data | |||
| # and generally don't have any Chinese data in them (there are Chinese | |||
| # characters in the vocabulary because Wikipedia does have some Chinese | |||
| # words in the English Wikipedia.). | |||
| text = self._tokenize_chinese_chars(text) | |||
| orig_tokens = whitespace_tokenize(text) | |||
| split_tokens = [] | |||
| for token in orig_tokens: | |||
| if self.do_lower_case and token not in self.never_split: | |||
| token = token.lower() | |||
| token = self._run_strip_accents(token) | |||
| split_tokens.extend(self._run_split_on_punc(token)) | |||
| output_tokens = whitespace_tokenize(' '.join(split_tokens)) | |||
| return output_tokens | |||
| def _run_strip_accents(self, text): | |||
| """Strips accents from a piece of text.""" | |||
| text = unicodedata.normalize('NFD', text) | |||
| output = [] | |||
| for char in text: | |||
| cat = unicodedata.category(char) | |||
| if cat == 'Mn': | |||
| continue | |||
| output.append(char) | |||
| return ''.join(output) | |||
| def _run_split_on_punc(self, text): | |||
| """Splits punctuation on a piece of text.""" | |||
| if text in self.never_split: | |||
| return [text] | |||
| chars = list(text) | |||
| i = 0 | |||
| start_new_word = True | |||
| output = [] | |||
| while i < len(chars): | |||
| char = chars[i] | |||
| if _is_punctuation(char): | |||
| output.append([char]) | |||
| start_new_word = True | |||
| else: | |||
| if start_new_word: | |||
| output.append([]) | |||
| start_new_word = False | |||
| output[-1].append(char) | |||
| i += 1 | |||
| return [''.join(x) for x in output] | |||
| def _tokenize_chinese_chars(self, text): | |||
| """Adds whitespace around any CJK character.""" | |||
| output = [] | |||
| for char in text: | |||
| cp = ord(char) | |||
| if self._is_chinese_char(cp): | |||
| output.append(' ') | |||
| output.append(char) | |||
| output.append(' ') | |||
| else: | |||
| output.append(char) | |||
| return ''.join(output) | |||
| def _is_chinese_char(self, cp): | |||
| """Checks whether CP is the codepoint of a CJK character.""" | |||
| # This defines a "chinese character" as anything in the CJK Unicode block: | |||
| # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) | |||
| # | |||
| # Note that the CJK Unicode block is NOT all Japanese and Korean characters, | |||
| # despite its name. The modern Korean Hangul alphabet is a different block, | |||
| # as is Japanese Hiragana and Katakana. Those alphabets are used to write | |||
| # space-separated words, so they are not treated specially and handled | |||
| # like all of the other languages. | |||
| tmp = (cp >= 0x4E00 and cp <= 0x9FFF) | |||
| tmp = tmp or (cp >= 0x3400 and cp <= 0x4DBF) | |||
| tmp = tmp or (cp >= 0x20000 and cp <= 0x2A6DF) | |||
| tmp = tmp or (cp >= 0x2A700 and cp <= 0x2B73F) | |||
| tmp = tmp or (cp >= 0x2B740 and cp <= 0x2B81F) | |||
| tmp = tmp or (cp >= 0x2B820 and cp <= 0x2CEAF) | |||
| tmp = tmp or (cp >= 0xF900 and cp <= 0xFAFF) | |||
| tmp = tmp or (cp >= 0x2F800 and cp <= 0x2FA1F) | |||
| if tmp: | |||
| return True | |||
| return False | |||
| def _clean_text(self, text): | |||
| """Performs invalid character removal and whitespace cleanup on text.""" | |||
| output = [] | |||
| for char in text: | |||
| cp = ord(char) | |||
| if cp == 0 or cp == 0xfffd or _is_control(char): | |||
| continue | |||
| if _is_whitespace(char): | |||
| output.append(' ') | |||
| else: | |||
| output.append(char) | |||
| return ''.join(output) | |||
| class WordpieceTokenizer(object): | |||
| """Runs WordPiece tokenization.""" | |||
| def __init__(self, vocab, unk_token='[UNK]', max_input_chars_per_word=100): | |||
| self.vocab = vocab | |||
| self.unk_token = unk_token | |||
| self.max_input_chars_per_word = max_input_chars_per_word | |||
| def tokenize(self, text): | |||
| """Tokenizes a piece of text into its word pieces. | |||
| This uses a greedy longest-match-first algorithm to perform tokenization | |||
| using the given vocabulary. | |||
| For example: | |||
| input = "unaffable" | |||
| output = ["un", "##aff", "##able"] | |||
| Args: | |||
| text: A single token or whitespace separated tokens. This should have | |||
| already been passed through `BasicTokenizer`. | |||
| Returns: | |||
| A list of wordpiece tokens. | |||
| """ | |||
| output_tokens = [] | |||
| for token in whitespace_tokenize(text): | |||
| chars = list(token) | |||
| if len(chars) > self.max_input_chars_per_word: | |||
| output_tokens.append(self.unk_token) | |||
| continue | |||
| is_bad = False | |||
| start = 0 | |||
| sub_tokens = [] | |||
| while start < len(chars): | |||
| end = len(chars) | |||
| cur_substr = None | |||
| while start < end: | |||
| substr = ''.join(chars[start:end]) | |||
| if start > 0: | |||
| substr = '##' + substr | |||
| if substr in self.vocab: | |||
| cur_substr = substr | |||
| break | |||
| end -= 1 | |||
| if cur_substr is None: | |||
| is_bad = True | |||
| break | |||
| sub_tokens.append(cur_substr) | |||
| start = end | |||
| if is_bad: | |||
| output_tokens.append(self.unk_token) | |||
| else: | |||
| output_tokens.extend(sub_tokens) | |||
| return output_tokens | |||
| def _is_whitespace(char): | |||
| """Checks whether `chars` is a whitespace character.""" | |||
| # \t, \n, and \r are technically control characters but we treat them | |||
| # as whitespace since they are generally considered as such. | |||
| if char == ' ' or char == '\t' or char == '\n' or char == '\r': | |||
| return True | |||
| cat = unicodedata.category(char) | |||
| if cat == 'Zs': | |||
| return True | |||
| return False | |||
| def _is_control(char): | |||
| """Checks whether `chars` is a control character.""" | |||
| # These are technically control characters but we count them as whitespace | |||
| # characters. | |||
| if char == '\t' or char == '\n' or char == '\r': | |||
| return False | |||
| cat = unicodedata.category(char) | |||
| if cat.startswith('C'): | |||
| return True | |||
| return False | |||
| def _is_punctuation(char): | |||
| """Checks whether `chars` is a punctuation character.""" | |||
| cp = ord(char) | |||
| # We treat all non-letter/number ASCII as punctuation. | |||
| # Characters such as "^", "$", and "`" are not in the Unicode | |||
| # Punctuation class but we treat them as punctuation anyways, for | |||
| # consistency. | |||
| tmp = (cp >= 33 and cp <= 47) | |||
| tmp = tmp or (cp >= 58 and cp <= 64) | |||
| tmp = tmp or (cp >= 91 and cp <= 96) | |||
| tmp = tmp or (cp >= 123 and cp <= 126) | |||
| if tmp: | |||
| return True | |||
| cat = unicodedata.category(char) | |||
| if cat.startswith('P'): | |||
| return True | |||
| return False | |||
| # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Tokenization classes for OpenAI GPT.""" | |||
| try: | |||
| from functools import lru_cache | |||
| except ImportError: | |||
| # Just a dummy decorator to get the checks to run on python2 | |||
| # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. | |||
| def lru_cache(): | |||
| return lambda func: func | |||
| @lru_cache() | |||
| def bytes_to_unicode(): | |||
| """ | |||
| Returns list of utf-8 byte and a corresponding list of unicode strings. | |||
| The reversible bpe codes work on unicode strings. | |||
| This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. | |||
| When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. | |||
| This is a significant percentage of your normal, say, 32K bpe vocab. | |||
| To avoid that, we want lookup tables between utf-8 bytes and unicode strings. | |||
| And avoids mapping to whitespace/control characters the bpe code barfs on. | |||
| """ | |||
| _chr = unichr if sys.version_info[0] == 2 else chr | |||
| bs = list(range(ord('!'), | |||
| ord('~') + 1)) + list(range( | |||
| ord('¡'), | |||
| ord('¬') + 1)) + list(range(ord('®'), | |||
| ord('ÿ') + 1)) | |||
| cs = bs[:] | |||
| n = 0 | |||
| for b in range(2**8): | |||
| if b not in bs: | |||
| bs.append(b) | |||
| cs.append(2**8 + n) | |||
| n += 1 | |||
| cs = [_chr(n) for n in cs] | |||
| return dict(zip(bs, cs)) | |||
| def get_pairs(word): | |||
| """Return set of symbol pairs in a word. | |||
| Word is represented as tuple of symbols (symbols being variable-length strings). | |||
| """ | |||
| pairs = set() | |||
| prev_char = word[0] | |||
| for char in word[1:]: | |||
| pairs.add((prev_char, char)) | |||
| prev_char = char | |||
| return pairs | |||
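get_pairs and bytes_to_unicode behave as in the upstream GPT-2 tokenizer; two quick sanity checks:

    get_pairs(('l', 'o', 'w', 'e', 'r'))
    # -> {('l', 'o'), ('o', 'w'), ('w', 'e'), ('e', 'r')}

    b2u = bytes_to_unicode()   # 256-entry, reversible byte -> printable-char map
    assert len(b2u) == 256 and len(set(b2u.values())) == 256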
| class GPT2Tokenizer(object): | |||
| """ | |||
| GPT-2 BPE tokenizer. Peculiarities: | |||
| - Byte-level BPE | |||
| """ | |||
| def __init__(self, | |||
| vocab_file, | |||
| merges_file, | |||
| errors='replace', | |||
| special_tokens=None, | |||
| max_len=None): | |||
| self.max_len = max_len if max_len is not None else int(1e12) | |||
| self.encoder = json.load(open(vocab_file)) | |||
| self.decoder = {v: k for k, v in self.encoder.items()} | |||
| self.errors = errors # how to handle errors in decoding | |||
| self.byte_encoder = bytes_to_unicode() | |||
| self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} | |||
| bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] | |||
| bpe_merges = [tuple(merge.split()) for merge in bpe_data] | |||
| self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) | |||
| self.cache = {} | |||
| # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions | |||
| self.pat = re.compile( | |||
| r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" | |||
| ) | |||
| self.special_tokens = {} | |||
| self.special_tokens_decoder = {} | |||
| self.set_special_tokens(special_tokens) | |||
| def __len__(self): | |||
| return len(self.encoder) + len(self.special_tokens) | |||
| def set_special_tokens(self, special_tokens): | |||
| """ Add a list of additional tokens to the encoder. | |||
| The additional tokens are indexed starting from the last index of the | |||
| current vocabulary in the order of the `special_tokens` list. | |||
| """ | |||
| if not special_tokens: | |||
| self.special_tokens = {} | |||
| self.special_tokens_decoder = {} | |||
| return | |||
| self.special_tokens = dict((tok, len(self.encoder) + i) | |||
| for i, tok in enumerate(special_tokens)) | |||
| self.special_tokens_decoder = { | |||
| v: k | |||
| for k, v in self.special_tokens.items() | |||
| } | |||
| logger.info('Special tokens {}'.format(self.special_tokens)) | |||
| def bpe(self, token): | |||
| if token in self.cache: | |||
| return self.cache[token] | |||
| word = tuple(token) | |||
| pairs = get_pairs(word) | |||
| if not pairs: | |||
| return token | |||
| while True: | |||
| bigram = min( | |||
| pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) | |||
| if bigram not in self.bpe_ranks: | |||
| break | |||
| first, second = bigram | |||
| new_word = [] | |||
| i = 0 | |||
| while i < len(word): | |||
| try: | |||
| j = word.index(first, i) | |||
| new_word.extend(word[i:j]) | |||
| i = j | |||
| except Exception: | |||
| new_word.extend(word[i:]) | |||
| break | |||
| if word[i] == first and i < len(word) - 1 and word[ | |||
| i + 1] == second: | |||
| new_word.append(first + second) | |||
| i += 2 | |||
| else: | |||
| new_word.append(word[i]) | |||
| i += 1 | |||
| new_word = tuple(new_word) | |||
| word = new_word | |||
| if len(word) == 1: | |||
| break | |||
| else: | |||
| pairs = get_pairs(word) | |||
| word = ' '.join(word) | |||
| self.cache[token] = word | |||
| return word | |||
| def tokenize(self, text): | |||
| """ Tokenize a string. """ | |||
| bpe_tokens = [] | |||
| for token in re.findall(self.pat, text): | |||
| token = ''.join(self.byte_encoder[ord(b)] for b in token | |||
| if ord(b) in self.byte_encoder) | |||
| if token == '': | |||
| continue | |||
| bpe_tokens.extend( | |||
| bpe_token for bpe_token in self.bpe(token).split(' ')) | |||
| return bpe_tokens | |||
| def convert_tokens_to_ids(self, tokens): | |||
| """ Converts a sequence of tokens into ids using the vocab. """ | |||
| ids = [] | |||
| python_version_3 = isinstance(tokens, str) | |||
| python_version_2 = ( | |||
| sys.version_info[0] == 2 and isinstance(tokens, unicode)) | |||
| if python_version_3 or python_version_2: | |||
| if tokens in self.special_tokens: | |||
| return self.special_tokens[tokens] | |||
| else: | |||
| return self.encoder.get(tokens, 0) | |||
| for token in tokens: | |||
| if token in self.special_tokens: | |||
| ids.append(self.special_tokens[token]) | |||
| else: | |||
| ids.append(self.encoder.get(token, 0)) | |||
| if len(ids) > self.max_len: | |||
| logger.warning( | |||
| 'Token indices sequence length is longer than the specified maximum ' | |||
| ' sequence length for this OpenAI GPT model ({} > {}). Running this' | |||
| ' sequence through the model will result in indexing errors'. | |||
| format(len(ids), self.max_len)) | |||
| return ids | |||
| def convert_ids_to_tokens(self, ids, skip_special_tokens=False): | |||
| """Converts a sequence of ids in BPE tokens using the vocab.""" | |||
| tokens = [] | |||
| for i in ids: | |||
| if i in self.special_tokens_decoder: | |||
| if not skip_special_tokens: | |||
| tokens.append(self.special_tokens_decoder[i]) | |||
| else: | |||
| tokens.append(self.decoder[i]) | |||
| return tokens | |||
| def encode(self, text): | |||
| return self.convert_tokens_to_ids(self.tokenize(text)) | |||
| def decode(self, tokens): | |||
| text = ''.join([self.decoder[token] for token in tokens]) | |||
| text = bytearray([self.byte_decoder[c] for c in text]).decode( | |||
| 'utf-8', errors=self.errors) | |||
| return text | |||
| @@ -0,0 +1,73 @@ | |||
| """ | |||
| MetricsTracker class | |||
| """ | |||
| import math | |||
| from collections import defaultdict | |||
| class MetricsTracker(object): | |||
| """ Tracking metrics. """ | |||
| def __init__(self): | |||
| self.metrics_val = defaultdict(float) # for one batch | |||
| self.metrics_avg = defaultdict(float) # avg batches | |||
| self.num_samples = 0 | |||
| def update(self, metrics, num_samples): | |||
| for key, val in metrics.items(): | |||
| if val is not None: | |||
| val = float(val) # [val] -> val | |||
| self.metrics_val[key] = val | |||
| avg_val = \ | |||
| (self.metrics_avg.get(key, 0) * self.num_samples + val * num_samples) / \ | |||
| (self.num_samples + num_samples) | |||
| self.metrics_avg[key] = avg_val | |||
| self.num_samples += num_samples | |||
| def clear(self): | |||
| self.metrics_val = defaultdict(float) | |||
| self.metrics_avg = defaultdict(float) | |||
| self.num_samples = 0 | |||
| def items(self): | |||
| return self.metrics_avg.items() | |||
| def get(self, name): | |||
| if self.num_samples == 0: | |||
| raise ValueError('There is no data in Metrics.') | |||
| return self.metrics_avg.get(name) | |||
| def state_dict(self): | |||
| return { | |||
| 'metrics_val': self.metrics_val, | |||
| 'metrics_avg': self.metrics_avg, | |||
| 'num_samples': self.num_samples, | |||
| } | |||
| def load_state_dict(self, state_dict): | |||
| self.metrics_val = state_dict['metrics_val'] | |||
| self.metrics_avg = state_dict['metrics_avg'] | |||
| self.num_samples = state_dict['num_samples'] | |||
| def value(self): | |||
| metric_strs = [] | |||
| for key, val in self.metrics_val.items(): | |||
| metric_str = f'{key.upper()}-{val:.3f}' | |||
| metric_strs.append(metric_str) | |||
| if 'token_nll' in self.metrics_val: | |||
| metric_str = f"TOKEN_PPL-{math.exp(self.metrics_val['token_nll']):.3f}" | |||
| metric_strs.append(metric_str) | |||
| metric_strs = ' '.join(metric_strs) | |||
| return metric_strs | |||
| def summary(self): | |||
| metric_strs = [] | |||
| for key, val in self.metrics_avg.items(): | |||
| metric_str = f'{key.upper()}-{val:.3f}' | |||
| metric_strs.append(metric_str) | |||
| if 'token_nll' in self.metrics_avg: | |||
| metric_str = f"TOKEN_PPL-{math.exp(self.metrics_avg['token_nll']):.3f}" | |||
| metric_strs.append(metric_str) | |||
| metric_strs = ' '.join(metric_strs) | |||
| return metric_strs | |||
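MetricsTracker keeps both the last batch value and a sample-weighted running average; a small sanity sketch:

    tracker = MetricsTracker()
    tracker.update({'token_nll': 1.0}, num_samples=2)
    tracker.update({'token_nll': 2.0}, num_samples=2)
    tracker.get('token_nll')   # -> 1.5 (weighted average)
    tracker.summary()          # -> 'TOKEN_NLL-1.500 TOKEN_PPL-4.482'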
| @@ -0,0 +1,761 @@ | |||
| """ | |||
| Trainer class. | |||
| """ | |||
| import logging | |||
| import os | |||
| import sys | |||
| import time | |||
| from collections import OrderedDict | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| from tqdm import tqdm | |||
| from transformers.optimization import AdamW, get_linear_schedule_with_warmup | |||
| from .....utils.nlp.space import ontology | |||
| from ..metrics.metrics_tracker import MetricsTracker | |||
| def get_logger(log_path, name='default'): | |||
| logger = logging.getLogger(name) | |||
| logger.propagate = False | |||
| logger.setLevel(logging.DEBUG) | |||
| formatter = logging.Formatter('%(message)s') | |||
| sh = logging.StreamHandler(sys.stdout) | |||
| sh.setFormatter(formatter) | |||
| logger.addHandler(sh) | |||
| fh = logging.FileHandler(log_path, mode='w') | |||
| fh.setFormatter(formatter) | |||
| logger.addHandler(fh) | |||
| return logger | |||
| class Trainer(object): | |||
| def __init__(self, | |||
| model, | |||
| to_tensor, | |||
| config, | |||
| logger=None, | |||
| lr_scheduler=None, | |||
| optimizer=None, | |||
| reader=None, | |||
| evaluator=None): | |||
| self.to_tensor = to_tensor | |||
| self.do_train = config.do_train | |||
| self.do_infer = config.do_infer | |||
| self.is_decreased_valid_metric = config.Trainer.valid_metric_name[ | |||
| 0] == '-' | |||
| self.valid_metric_name = config.Trainer.valid_metric_name[1:] | |||
| self.num_epochs = config.Trainer.num_epochs | |||
| # self.save_dir = config.Trainer.save_dir | |||
| self.log_steps = config.Trainer.log_steps | |||
| self.valid_steps = config.Trainer.valid_steps | |||
| self.save_checkpoint = config.Trainer.save_checkpoint | |||
| self.save_summary = config.Trainer.save_summary | |||
| self.lr = config.Model.lr | |||
| self.weight_decay = config.Model.weight_decay | |||
| self.batch_size = config.Trainer.batch_size | |||
| self.gradient_accumulation_steps = config.Model.gradient_accumulation_steps | |||
| self.warmup_steps = config.Model.warmup_steps | |||
| self.gpu = config.Trainer.gpu | |||
| self.lr_scheduler = lr_scheduler | |||
| self.optimizer = optimizer | |||
| self.model = model | |||
| self.func_model = self.model.module if self.gpu > 1 else self.model | |||
| self.reader = reader | |||
| self.evaluator = evaluator | |||
| self.tokenizer = reader.tokenizer | |||
| # if not os.path.exists(self.save_dir): | |||
| # os.makedirs(self.save_dir) | |||
| # self.logger = logger or get_logger(os.path.join(self.save_dir, "trainer.log"), "trainer") | |||
| self.logger = logger or get_logger('trainer.log', 'trainer') | |||
| self.batch_metrics_tracker = MetricsTracker() | |||
| self.token_metrics_tracker = MetricsTracker() | |||
| self.best_valid_metric = float( | |||
| 'inf' if self.is_decreased_valid_metric else '-inf') | |||
| self.epoch = 0 | |||
| def decode_generated_bspn_resp(self, generated): | |||
| """ | |||
| decode generated | |||
| return decoded ('bspn', 'resp') | |||
| """ | |||
| decoded = {} | |||
| eos_r_id = self.reader.eos_r_id | |||
| eos_b_id = self.reader.eos_b_id | |||
| # eos_r may not exist if gpt2 generated repetitive words. | |||
| if eos_r_id in generated: | |||
| eos_r_idx = generated.index(eos_r_id) | |||
| else: | |||
| eos_r_idx = len(generated) - 1 | |||
| # self.logger.info('eos_r not in generated: ' + self.tokenizer.decode(generated)) | |||
| # predicted bspn, resp | |||
| eos_b_idx = generated.index(eos_b_id) | |||
| decoded['bspn'] = generated[:eos_b_idx + 1] | |||
| decoded['resp'] = generated[eos_b_idx + 1:eos_r_idx + 1] | |||
| return decoded | |||
| def decode_generated_act_resp(self, generated): | |||
| """ | |||
| decode generated | |||
| return the decoded dict with 'resp' (and 'aspn' when the act is also predicted) | |||
| """ | |||
| decoded = {} | |||
| eos_a_id = self.reader.eos_a_id | |||
| eos_r_id = self.reader.eos_r_id | |||
| # eos_b_id = self.reader.eos_b_id | |||
| # eos_r may not exist if gpt2 generated repetitive words. | |||
| if eos_r_id in generated: | |||
| eos_r_idx = generated.index(eos_r_id) | |||
| else: | |||
| eos_r_idx = len(generated) - 1 | |||
| msg = 'eos_r not in generated: ' + self.tokenizer.decode(generated) | |||
| self.logger.info(msg) | |||
| if self.reader.use_true_curr_aspn: # only predict resp | |||
| decoded['resp'] = generated[:eos_r_idx + 1] | |||
| else: # predicted aspn, resp | |||
| eos_a_idx = generated.index(eos_a_id) | |||
| decoded['aspn'] = generated[:eos_a_idx + 1] | |||
| decoded['resp'] = generated[eos_a_idx + 1:eos_r_idx + 1] | |||
| return decoded | |||
| def decode_generated_bspn(self, generated): | |||
| eos_b_id = self.reader.eos_b_id | |||
| if eos_b_id in generated: | |||
| eos_b_idx = generated.index(eos_b_id) | |||
| else: | |||
| eos_b_idx = len(generated) - 1 | |||
| return generated[:eos_b_idx + 1] | |||
| def set_optimizers(self): | |||
| """ | |||
| Setup the optimizer and the learning rate scheduler. | |||
| from transformers.Trainer | |||
| parameters from cfg: lr (1e-3); warmup_steps | |||
| """ | |||
| # Prepare optimizer and schedule (linear warmup and decay) | |||
| no_decay = ['bias', 'norm.weight'] | |||
| optimizer_grouped_parameters = [ | |||
| { | |||
| 'params': [ | |||
| p for n, p in self.model.named_parameters() | |||
| if not any(nd in n for nd in no_decay) | |||
| ], | |||
| 'weight_decay': | |||
| self.weight_decay, | |||
| }, | |||
| { | |||
| 'params': [ | |||
| p for n, p in self.model.named_parameters() | |||
| if any(nd in n for nd in no_decay) | |||
| ], | |||
| 'weight_decay': | |||
| 0.0, | |||
| }, | |||
| ] | |||
| optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr) | |||
| num_training_steps = \ | |||
| self.reader.set_stats['train']['num_training_steps_per_epoch'] \ | |||
| * self.num_epochs \ | |||
| // self.gradient_accumulation_steps | |||
| num_warmup_steps = self.warmup_steps if self.warmup_steps >= 0 else int( | |||
| num_training_steps * 0.1) | |||
| lr_scheduler = get_linear_schedule_with_warmup( | |||
| optimizer, | |||
| num_warmup_steps=num_warmup_steps, | |||
| num_training_steps=num_training_steps) | |||
| self.optimizer = optimizer | |||
| self.lr_scheduler = lr_scheduler | |||
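The schedule length is derived from turn-level steps, epochs and gradient accumulation. For example, with 10,000 turn-level steps per epoch, 10 epochs and gradient_accumulation_steps=4 (illustrative numbers only):

    # num_training_steps = 10000 * 10 // 4 = 25000 optimizer updates
    # num_warmup_steps   = warmup_steps if warmup_steps >= 0 else int(25000 * 0.1) = 2500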
| def train(self, train_data, dev_data): | |||
| # log info | |||
| set_stats = self.reader.set_stats['train'] | |||
| self.logger.info('***** Running training *****') | |||
| self.logger.info( | |||
| ' Num Training steps (one turn in a batch of dialogs) per epoch = %d', | |||
| set_stats['num_training_steps_per_epoch']) | |||
| self.logger.info(' Num Turns = %d', set_stats['num_turns']) | |||
| self.logger.info(' Num Dialogs = %d', set_stats['num_dials']) | |||
| self.logger.info(' Num Epochs = %d', self.num_epochs) | |||
| self.logger.info(' Batch size = %d', self.batch_size) | |||
| self.logger.info(' Gradient Accumulation steps = %d', | |||
| self.gradient_accumulation_steps) | |||
| steps = set_stats[ | |||
| 'num_training_steps_per_epoch'] * self.num_epochs // self.gradient_accumulation_steps | |||
| msg = ' Total optimization steps = %d' % steps | |||
| self.logger.info(msg) | |||
| # begin training | |||
| num_epochs = self.num_epochs - self.epoch | |||
| for epoch in range(num_epochs): | |||
| self.train_epoch(train_data=train_data, dev_data=dev_data) | |||
| def train_epoch(self, train_data, dev_data): | |||
| """ | |||
| Train an epoch. | |||
| """ | |||
| raise NotImplementedError | |||
| def infer(self, data_type): | |||
| """ | |||
| Inference interface. | |||
| """ | |||
| raise NotImplementedError | |||
| def forward(self, turn, old_pv_turn): | |||
| """ | |||
| one turn inference | |||
| """ | |||
| raise NotImplementedError | |||
| def save(self, is_best=False): | |||
| """ save """ | |||
| train_state = { | |||
| 'epoch': self.epoch, | |||
| 'best_valid_metric': self.best_valid_metric, | |||
| 'optimizer': self.optimizer.state_dict() | |||
| } | |||
| if self.lr_scheduler is not None: | |||
| train_state['lr_scheduler'] = self.lr_scheduler.state_dict() | |||
| # Save checkpoint | |||
| if self.save_checkpoint: | |||
| model_file = os.path.join(self.save_dir, | |||
| f'state_epoch_{self.epoch}.model') | |||
| torch.save(self.model.state_dict(), model_file) | |||
| self.logger.info(f"Saved model state to '{model_file}'") | |||
| train_file = os.path.join(self.save_dir, | |||
| f'state_epoch_{self.epoch}.train') | |||
| torch.save(train_state, train_file) | |||
| self.logger.info(f"Saved train state to '{train_file}'") | |||
| # Save current best model | |||
| if is_best: | |||
| best_model_file = os.path.join(self.save_dir, 'best.model') | |||
| torch.save(self.model.state_dict(), best_model_file) | |||
| best_train_file = os.path.join(self.save_dir, 'best.train') | |||
| torch.save(train_state, best_train_file) | |||
| self.logger.info( | |||
| f"Saved best model state to '{best_model_file}' with new best valid metric " | |||
| f'{self.valid_metric_name.upper()}={self.best_valid_metric:.3f}' | |||
| ) | |||
| def load(self): | |||
| """ load """ | |||
| def _load_model_state(): | |||
| model_state_dict = torch.load( | |||
| f'{self.func_model.init_checkpoint}', | |||
| map_location=lambda storage, loc: storage) | |||
| if 'module.' in list(model_state_dict.keys())[0]: | |||
| new_model_state_dict = OrderedDict() | |||
| for k, v in model_state_dict.items(): | |||
| assert k[:7] == 'module.' | |||
| new_model_state_dict[k[7:]] = v | |||
| model_state_dict = new_model_state_dict | |||
| new_model_state_dict = OrderedDict() | |||
| parameters = { | |||
| name: param | |||
| for name, param in self.func_model.named_parameters() | |||
| } | |||
| for name, param in model_state_dict.items(): | |||
| if name in parameters: | |||
| if param.shape != parameters[name].shape: | |||
| assert hasattr(param, 'numpy') | |||
| arr = param.numpy() | |||
| z = np.random.normal( | |||
| scale=self.func_model.initializer_range, | |||
| size=parameters[name].shape).astype('float32') | |||
| if name == 'embedder.token_embedding.weight': | |||
| z[-param.shape[0]:] = arr | |||
| print( | |||
| f'part of parameter({name}) randomly initialized from a normal distribution' | |||
| ) | |||
| else: | |||
| if z.shape[0] < param.shape[0]: | |||
| z = arr[:z.shape[0]] | |||
| print(f'part of parameter({name}) is dropped') | |||
| else: | |||
| z[:param.shape[0]] = arr | |||
| print( | |||
| f'part of parameter({name}) randomly initialized from a normal distribution' | |||
| ) | |||
| dtype, device = param.dtype, param.device | |||
| z = torch.tensor(z, dtype=dtype, device=device) | |||
| new_model_state_dict[name] = z | |||
| else: | |||
| new_model_state_dict[name] = param | |||
| else: | |||
| print(f'parameter({name}) is dropped') | |||
| model_state_dict = new_model_state_dict | |||
| for name in parameters: | |||
| if name not in model_state_dict: | |||
| if parameters[name].requires_grad: | |||
| print(f'parameter({name}) randomly initialized from a normal distribution') | |||
| z = np.random.normal( | |||
| scale=self.func_model.initializer_range, | |||
| size=parameters[name].shape).astype('float32') | |||
| dtype, device = parameters[name].dtype, parameters[ | |||
| name].device | |||
| model_state_dict[name] = torch.tensor( | |||
| z, dtype=dtype, device=device) | |||
| else: | |||
| model_state_dict[name] = parameters[name] | |||
| self.func_model.load_state_dict(model_state_dict) | |||
| self.logger.info( | |||
| f"Loaded model state from '{self.func_model.init_checkpoint}.model'" | |||
| ) | |||
| def _load_train_state(): | |||
| train_file = f'{self.func_model.init_checkpoint}.train' | |||
| if os.path.exists(train_file): | |||
| train_state_dict = torch.load( | |||
| train_file, map_location=lambda storage, loc: storage) | |||
| self.epoch = train_state_dict['epoch'] | |||
| self.best_valid_metric = train_state_dict['best_valid_metric'] | |||
| if self.optimizer is not None and 'optimizer' in train_state_dict: | |||
| self.optimizer.load_state_dict( | |||
| train_state_dict['optimizer']) | |||
| if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict: | |||
| self.lr_scheduler.load_state_dict( | |||
| train_state_dict['lr_scheduler']) | |||
| self.logger.info( | |||
| f"Loaded train state from '{train_file}' with (epoch-{self.epoch} " | |||
| f'best_valid_metric={self.best_valid_metric:.3f})') | |||
| else: | |||
| self.logger.info('Loaded no train state') | |||
| if self.func_model.init_checkpoint is None: | |||
| self.logger.info('Loaded no model !!!') | |||
| return | |||
| if self.do_train: | |||
| _load_model_state() | |||
| return | |||
| if self.do_infer: | |||
| _load_model_state() | |||
| _load_train_state() | |||
| class MultiWOZTrainer(Trainer): | |||
| def __init__(self, | |||
| model, | |||
| to_tensor, | |||
| config, | |||
| logger=None, | |||
| lr_scheduler=None, | |||
| optimizer=None, | |||
| reader=None, | |||
| evaluator=None): | |||
| super(MultiWOZTrainer, | |||
| self).__init__(model, to_tensor, config, logger, lr_scheduler, | |||
| optimizer, reader, evaluator) | |||
| def train_epoch(self, train_data, dev_data): | |||
| """ | |||
| Train an epoch. | |||
| """ | |||
| times = [] | |||
| epoch_step = 0 | |||
| global_step = 0 | |||
| tr_batch_loss = 0.0 | |||
| tr_token_loss = 0.0 | |||
| self.epoch += 1 | |||
| self.batch_metrics_tracker.clear() | |||
| self.token_metrics_tracker.clear() | |||
| num_training_steps = \ | |||
| self.reader.set_stats['train']['num_training_steps_per_epoch'] // \ | |||
| self.gradient_accumulation_steps # similar to the original num_batches | |||
| self.model.zero_grad() | |||
| data_iterator = self.reader.get_data_iterator(all_batches=train_data) | |||
| for batch_idx, dial_batch in enumerate(data_iterator): | |||
| pv_batch = [] | |||
| for turn_num, turn_batch in enumerate(dial_batch): | |||
| first_turn = (turn_num == 0) | |||
| samples, pv_batch = self.reader.convert_batch_turn( | |||
| turn_batch, pv_batch, first_turn) | |||
| batch, batch_size = self.reader.collate_fn_multi_turn( | |||
| samples=samples) | |||
| batch = type(batch)( | |||
| map(lambda kv: (kv[0], self.to_tensor(kv[1])), | |||
| batch.items())) | |||
| # Do a training iteration | |||
| start_time = time.time() | |||
| metrics = self.model(batch, is_training=True) | |||
| if self.gpu > 1: | |||
| for metric in metrics: | |||
| if metric is not None: | |||
| assert len(metric) == self.gpu | |||
| nll, token_nll, token_num = metrics | |||
| metrics = {} | |||
| token_num = torch.sum(token_num) | |||
| token_nll = \ | |||
| torch.sum(nll) * (batch_size / self.gpu) / \ | |||
| token_num | |||
| nll = torch.mean(nll) | |||
| metrics['token_num'] = token_num | |||
| metrics['token_nll'] = token_nll | |||
| metrics['nll'] = nll | |||
| loss = token_nll if self.func_model.token_loss else nll | |||
| metrics['loss'] = loss | |||
| else: | |||
| loss = metrics['loss'] | |||
| self.func_model._optimize( | |||
| loss, do_update=False, optimizer=self.optimizer) | |||
| metrics = { | |||
| k: v.cpu().detach().numpy() | |||
| if isinstance(v, torch.Tensor) else v | |||
| for k, v in metrics.items() | |||
| } | |||
| token_num = metrics.pop('token_num', None) | |||
| # bow_num = metrics.pop("bow_num", None) | |||
| elapsed = time.time() - start_time | |||
| times.append(elapsed) | |||
| epoch_step += 1 | |||
| tr_batch_loss += metrics['nll'] | |||
| tr_token_loss += metrics['token_nll'] | |||
| batch_metrics = { | |||
| k: v | |||
| for k, v in metrics.items() if 'token' not in k | |||
| } | |||
| token_metrics = { | |||
| k: v | |||
| for k, v in metrics.items() if 'token' in k | |||
| } | |||
| self.batch_metrics_tracker.update(batch_metrics, batch_size) | |||
| self.token_metrics_tracker.update(token_metrics, token_num) | |||
| if (epoch_step % self.gradient_accumulation_steps == 0) or \ | |||
| (epoch_step == self.reader.set_stats['train']['num_training_steps_per_epoch']): | |||
| self.optimizer.step() | |||
| self.lr_scheduler.step() | |||
| self.optimizer.zero_grad() | |||
| global_step += 1 | |||
| if self.log_steps > 0 and global_step % self.log_steps == 0: | |||
| batch_metrics_message = self.batch_metrics_tracker.value( | |||
| ) | |||
| token_metrics_message = self.token_metrics_tracker.value( | |||
| ) | |||
| message_prefix = f'[Train][{self.epoch}][{global_step}/{num_training_steps}]' | |||
| avg_time = f'AVG_Time-{sum(times[-self.log_steps:]) / self.log_steps:.3f}' | |||
| message = ' '.join([ | |||
| message_prefix, batch_metrics_message, | |||
| token_metrics_message, avg_time | |||
| ]) | |||
| self.logger.info(message) | |||
| self.logger.info('-' * 150) | |||
| avg_batch_loss = tr_batch_loss / epoch_step | |||
| avg_token_loss = tr_token_loss / epoch_step | |||
| batch_metrics_message = self.batch_metrics_tracker.summary() | |||
| token_metrics_message = self.token_metrics_tracker.summary() | |||
| message_prefix = f'[Valid][{self.epoch}]' | |||
| message = ' '.join([ | |||
| message_prefix, batch_metrics_message, token_metrics_message, | |||
| str(avg_batch_loss), | |||
| str(avg_token_loss) | |||
| ]) | |||
| self.logger.info(message) | |||
| cur_valid_metric = self.batch_metrics_tracker.get( | |||
| self.valid_metric_name) | |||
| if self.is_decreased_valid_metric: | |||
| is_best = cur_valid_metric < self.best_valid_metric | |||
| else: | |||
| is_best = cur_valid_metric > self.best_valid_metric | |||
| if is_best: | |||
| self.best_valid_metric = cur_valid_metric | |||
| self.save(is_best) | |||
| self.logger.info('-' * 150) | |||
| return | |||
| def infer(self, data_type='test'): | |||
| """ | |||
| Inference interface. | |||
| """ | |||
| self.logger.info('Generation starts ...') | |||
| infer_save_file = os.path.join(self.save_dir, | |||
| f'infer_{self.epoch}.result.json') | |||
| infer_samples_save_file = os.path.join( | |||
| self.save_dir, f'infer_samples_{self.epoch}.result.json') | |||
| # Inference | |||
| result_collection = {} | |||
| begin_time = time.time() | |||
| eval_data = self.reader.get_eval_data(data_type) | |||
| set_stats = self.reader.set_stats[data_type] | |||
| self.logger.info('***** Running Evaluation *****') | |||
| self.logger.info(' Num Turns = %d', set_stats['num_turns']) | |||
| with torch.no_grad(): | |||
| pbar = tqdm(eval_data) | |||
| for dial_idx, dialog in enumerate(pbar): | |||
| pv_turn = {} | |||
| for turn_idx, turn in enumerate(dialog): | |||
| first_turn = (turn_idx == 0) | |||
| inputs, prompt_id = self.reader.convert_turn_eval( | |||
| turn, pv_turn, first_turn) | |||
| batch, batch_size = self.reader.collate_fn_multi_turn( | |||
| samples=[inputs]) | |||
| batch = type(batch)( | |||
| map(lambda kv: (kv[0], self.to_tensor(kv[1])), | |||
| batch.items())) | |||
| if self.reader.use_true_curr_bspn: # generate act, response | |||
| max_len = 60 | |||
| if not self.reader.use_true_curr_aspn: | |||
| max_len = 80 | |||
| outputs = self.func_model.infer( | |||
| inputs=batch, | |||
| start_id=prompt_id, | |||
| eos_id=self.reader.eos_r_id, | |||
| max_gen_len=max_len) | |||
| # resp_gen, need to trim previous context | |||
| generated = outputs[0].cpu().numpy().tolist() | |||
| try: | |||
| decoded = self.decode_generated_act_resp(generated) | |||
| except ValueError as exception: | |||
| self.logger.info(str(exception)) | |||
| self.logger.info(self.tokenizer.decode(generated)) | |||
| decoded = {'resp': [], 'bspn': [], 'aspn': []} | |||
| else: # predict bspn, access db, then generate act and resp | |||
| outputs = self.func_model.infer( | |||
| inputs=batch, | |||
| start_id=prompt_id, | |||
| eos_id=self.reader.eos_b_id, | |||
| max_gen_len=60) | |||
| generated_bs = outputs[0].cpu().numpy().tolist() | |||
| bspn_gen = self.decode_generated_bspn(generated_bs) | |||
| # check DB result | |||
| if self.reader.use_true_db_pointer: # To control whether current db is ground truth | |||
| db = turn['db'] | |||
| else: | |||
| db_result = self.reader.bspan_to_DBpointer( | |||
| self.tokenizer.decode(bspn_gen), | |||
| turn['turn_domain']) | |||
| assert len(turn['db']) == 4 | |||
| book_result = turn['db'][2] | |||
| assert isinstance(db_result, str) | |||
| db = \ | |||
| [self.reader.sos_db_id] + \ | |||
| self.tokenizer.convert_tokens_to_ids([db_result]) + \ | |||
| [book_result] + \ | |||
| [self.reader.eos_db_id] | |||
| prompt_id = self.reader.sos_a_id | |||
| prev_input = torch.tensor(bspn_gen + db) | |||
| if self.func_model.use_gpu: | |||
| prev_input = prev_input.cuda() | |||
| outputs_db = self.func_model.infer( | |||
| inputs=batch, | |||
| start_id=prompt_id, | |||
| eos_id=self.reader.eos_r_id, | |||
| max_gen_len=80, | |||
| prev_input=prev_input) | |||
| generated_ar = outputs_db[0].cpu().numpy().tolist() | |||
| try: | |||
| decoded = self.decode_generated_act_resp( | |||
| generated_ar) | |||
| decoded['bspn'] = bspn_gen | |||
| except ValueError as exception: | |||
| self.logger.info(str(exception)) | |||
| self.logger.info( | |||
| self.tokenizer.decode(generated_ar)) | |||
| decoded = {'resp': [], 'bspn': [], 'aspn': []} | |||
| turn['resp_gen'] = decoded['resp'] | |||
| turn['bspn_gen'] = turn[ | |||
| 'bspn'] if self.reader.use_true_curr_bspn else decoded[ | |||
| 'bspn'] | |||
| turn['aspn_gen'] = turn[ | |||
| 'aspn'] if self.reader.use_true_curr_aspn else decoded[ | |||
| 'aspn'] | |||
| turn['dspn_gen'] = turn['dspn'] | |||
| pv_turn['labels'] = inputs[ | |||
| 'labels'] # all true previous context | |||
| pv_turn['resp'] = turn[ | |||
| 'resp'] if self.reader.use_true_prev_resp else decoded[ | |||
| 'resp'] | |||
| if not self.reader.use_true_curr_bspn: | |||
| pv_turn['bspn'] = turn[ | |||
| 'bspn'] if self.reader.use_true_prev_bspn else decoded[ | |||
| 'bspn'] | |||
| pv_turn['db'] = turn[ | |||
| 'db'] if self.reader.use_true_prev_bspn else db | |||
| pv_turn['aspn'] = turn[ | |||
| 'aspn'] if self.reader.use_true_prev_aspn else decoded[ | |||
| 'aspn'] | |||
| tmp_dialog_result = self.reader.inverse_transpose_turn(dialog) | |||
| result_collection.update(tmp_dialog_result) | |||
| # compute tmp scores | |||
| results, _ = self.reader.wrap_result_lm(tmp_dialog_result) | |||
| bleu, success, match = self.evaluator.validation_metric( | |||
| results) | |||
| score = 0.5 * (success + match) + bleu | |||
| pbar.set_description( | |||
| 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' % | |||
| (match, success, bleu, score)) | |||
| # compute scores | |||
| results, _ = self.reader.wrap_result_lm(result_collection) | |||
| bleu, success, match = self.evaluator.validation_metric(results) | |||
| score = 0.5 * (success + match) + bleu | |||
| # log results | |||
| metrics_message = 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' %\ | |||
| (match, success, bleu, score) | |||
| message_prefix = f'[Infer][{self.epoch}]' | |||
| time_cost = f'TIME-{time.time() - begin_time:.3f}' | |||
| message = ' '.join([message_prefix, metrics_message, time_cost]) | |||
| self.logger.info(message) | |||
| # save results | |||
| eval_results = { | |||
| 'bleu': bleu, | |||
| 'success': success, | |||
| 'match': match, | |||
| 'score': score, | |||
| 'result': message | |||
| } | |||
| with open(infer_save_file, 'w') as fp: | |||
| json.dump(eval_results, fp, indent=2) | |||
| self.logger.info(f'Saved inference results to {infer_save_file}') | |||
| with open(infer_samples_save_file, 'w') as fp: | |||
| for sample in results: | |||
| line = json.dumps(sample) | |||
| fp.write(line) | |||
| fp.write('\n') | |||
| self.logger.info( | |||
| f'Saved inference samples to {infer_samples_save_file}') | |||
| return | |||
| def _get_turn_domain(self, old_pv_turn, bspn_gen_ids, first_turn): | |||
| def _get_slots(constraint): | |||
| domain_name = '' | |||
| slots = {} | |||
| for item in constraint: | |||
| if item in ontology.placeholder_tokens: | |||
| continue | |||
| if item in ontology.all_domains_with_bracket: | |||
| domain_name = item | |||
| slots[domain_name] = set() | |||
| else: | |||
| assert domain_name in ontology.all_domains_with_bracket | |||
| slots[domain_name].add(item) | |||
| return slots | |||
| turn_domain = [] | |||
| if first_turn and len(bspn_gen_ids) == 0: | |||
| turn_domain = ['[general]'] | |||
| return turn_domain | |||
| bspn_token = self.tokenizer.convert_ids_to_tokens(bspn_gen_ids) | |||
| turn_slots = _get_slots(bspn_token) | |||
| if first_turn: | |||
| return list(turn_slots.keys()) | |||
| assert 'bspn' in old_pv_turn | |||
| pv_bspn_token = self.tokenizer.convert_ids_to_tokens( | |||
| old_pv_turn['bspn']) | |||
| pv_turn_slots = _get_slots(pv_bspn_token) | |||
| for domain, value in turn_slots.items(): | |||
| pv_value = pv_turn_slots[ | |||
| domain] if domain in pv_turn_slots else set() | |||
| if len(value - pv_value) > 0 or len(pv_value - value) > 0: | |||
| turn_domain.append(domain) | |||
| if len(turn_domain) == 0: | |||
| turn_domain = list(turn_slots.keys()) | |||
| return turn_domain | |||
| def forward(self, turn, old_pv_turn): | |||
| with torch.no_grad(): | |||
| first_turn = True if len(old_pv_turn) == 0 else False | |||
| inputs, prompt_id = self.reader.convert_turn_eval( | |||
| turn, old_pv_turn, first_turn) | |||
| batch, batch_size = self.reader.collate_fn_multi_turn( | |||
| samples=[inputs]) | |||
| batch = type(batch)( | |||
| map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items())) | |||
| pv_turn = {} | |||
| outputs = self.func_model.infer( | |||
| inputs=batch, | |||
| start_id=prompt_id, | |||
| eos_id=self.reader.eos_b_id, | |||
| max_gen_len=60) | |||
| generated_bs = outputs[0].cpu().numpy().tolist() | |||
| bspn_gen = self.decode_generated_bspn(generated_bs) | |||
| turn_domain = self._get_turn_domain(old_pv_turn, bspn_gen, | |||
| first_turn) | |||
| db_result = self.reader.bspan_to_DBpointer( | |||
| self.tokenizer.decode(bspn_gen), turn_domain) | |||
| assert isinstance(db_result, str) | |||
| db = \ | |||
| [self.reader.sos_db_id] + \ | |||
| self.tokenizer.convert_tokens_to_ids([db_result]) + \ | |||
| [self.reader.eos_db_id] | |||
| prompt_id = self.reader.sos_a_id | |||
| prev_input = torch.tensor(bspn_gen + db) | |||
| if self.func_model.use_gpu: | |||
| prev_input = prev_input.cuda() | |||
| outputs_db = self.func_model.infer( | |||
| inputs=batch, | |||
| start_id=prompt_id, | |||
| eos_id=self.reader.eos_r_id, | |||
| max_gen_len=80, | |||
| prev_input=prev_input) | |||
| generated_ar = outputs_db[0].cpu().numpy().tolist() | |||
| decoded = self.decode_generated_act_resp(generated_ar) | |||
| decoded['bspn'] = bspn_gen | |||
| pv_turn['labels'] = inputs['labels'] | |||
| pv_turn['resp'] = decoded['resp'] | |||
| pv_turn['bspn'] = decoded['bspn'] | |||
| pv_turn['db'] = db | |||
| pv_turn['aspn'] = decoded['aspn'] | |||
| return pv_turn | |||
| @@ -0,0 +1,821 @@ | |||
| """ | |||
| Trainer class. | |||
| """ | |||
| import logging | |||
| import os | |||
| import sys | |||
| import time | |||
| from collections import OrderedDict | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| from tqdm import tqdm | |||
| from transformers.optimization import AdamW, get_linear_schedule_with_warmup | |||
| from ..metrics.metrics_tracker import MetricsTracker | |||
| def get_logger(log_path, name='default'): | |||
| logger = logging.getLogger(name) | |||
| logger.propagate = False | |||
| logger.setLevel(logging.DEBUG) | |||
| formatter = logging.Formatter('%(message)s') | |||
| sh = logging.StreamHandler(sys.stdout) | |||
| sh.setFormatter(formatter) | |||
| logger.addHandler(sh) | |||
| fh = logging.FileHandler(log_path, mode='w') | |||
| fh.setFormatter(formatter) | |||
| logger.addHandler(fh) | |||
| return logger | |||
| class Trainer(object): | |||
| def __init__(self, | |||
| model, | |||
| to_tensor, | |||
| config, | |||
| reader=None, | |||
| logger=None, | |||
| lr_scheduler=None, | |||
| optimizer=None): | |||
| self.model = model | |||
| self.to_tensor = to_tensor | |||
| self.do_train = config.do_train | |||
| self.do_infer = config.do_infer | |||
| self.is_decreased_valid_metric = config.Trainer.valid_metric_name[ | |||
| 0] == '-' | |||
| self.valid_metric_name = config.Trainer.valid_metric_name[1:] | |||
| self.num_epochs = config.Trainer.num_epochs | |||
| self.save_dir = config.Trainer.save_dir | |||
| self.log_steps = config.Trainer.log_steps | |||
| self.valid_steps = config.Trainer.valid_steps | |||
| self.save_checkpoint = config.Trainer.save_checkpoint | |||
| self.save_summary = config.Trainer.save_summary | |||
| self.learning_method = config.Dataset.learning_method | |||
| self.weight_decay = config.Model.weight_decay | |||
| self.warmup_steps = config.Model.warmup_steps | |||
| self.batch_size_label = config.Trainer.batch_size_label | |||
| self.batch_size_nolabel = config.Trainer.batch_size_nolabel | |||
| self.gpu = config.Trainer.gpu | |||
| self.lr = config.Model.lr | |||
| self.model = model | |||
| self.func_model = self.model.module if self.gpu > 1 else self.model | |||
| self.reader = reader | |||
| self.tokenizer = reader.tokenizer | |||
| self.lr_scheduler = lr_scheduler | |||
| self.optimizer = optimizer | |||
| # if not os.path.exists(self.save_dir): | |||
| # os.makedirs(self.save_dir) | |||
| # self.logger = logger or get_logger(os.path.join(self.save_dir, "trainer.log"), "trainer") | |||
| self.logger = logger or get_logger('trainer.log', 'trainer') | |||
| self.batch_metrics_tracker_label = MetricsTracker() | |||
| self.token_metrics_tracker_label = MetricsTracker() | |||
| self.batch_metrics_tracker_nolabel = MetricsTracker() | |||
| self.token_metrics_tracker_nolabel = MetricsTracker() | |||
| self.best_valid_metric = float( | |||
| 'inf' if self.is_decreased_valid_metric else '-inf') | |||
| self.epoch = 0 | |||
| self.batch_num = 0 | |||
| def set_optimizers(self, num_training_steps_per_epoch): | |||
| """ | |||
| Setup the optimizer and the learning rate scheduler. | |||
| Adapted from transformers.Trainer. | |||
| Parameters taken from cfg: lr (default 1e-3) and warmup_steps. | |||
| """ | |||
| # Prepare optimizer and schedule (linear warmup and decay) | |||
| no_decay = ['bias', 'norm.weight'] | |||
| optimizer_grouped_parameters = [ | |||
| { | |||
| 'params': [ | |||
| p for n, p in self.model.named_parameters() | |||
| if not any(nd in n for nd in no_decay) | |||
| ], | |||
| 'weight_decay': | |||
| self.weight_decay, | |||
| }, | |||
| { | |||
| 'params': [ | |||
| p for n, p in self.model.named_parameters() | |||
| if any(nd in n for nd in no_decay) | |||
| ], | |||
| 'weight_decay': | |||
| 0.0, | |||
| }, | |||
| ] | |||
| optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr) | |||
| num_training_steps = num_training_steps_per_epoch * self.num_epochs | |||
| num_warmup_steps = self.warmup_steps if self.warmup_steps >= 0 else int( | |||
| num_training_steps * 0.1) | |||
| lr_scheduler = get_linear_schedule_with_warmup( | |||
| optimizer, | |||
| num_warmup_steps=num_warmup_steps, | |||
| num_training_steps=num_training_steps) | |||
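| # linear warmup to self.lr over num_warmup_steps, then linear decay to 0 | |||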
| # reset optimizer and lr_scheduler | |||
| self.optimizer = optimizer | |||
| self.lr_scheduler = lr_scheduler | |||
| # log info | |||
| self.logger.info( | |||
| f'***** Running training: {self.learning_method} *****') | |||
| self.logger.info(' Num Epochs = %d', self.num_epochs) | |||
| self.logger.info( | |||
| ' Num Training steps (one turn in a batch of dialogs) per epoch = %d', | |||
| num_training_steps_per_epoch) | |||
| self.logger.info(' Batch size for labeled data = %d', | |||
| self.batch_size_label) | |||
| self.logger.info(' Batch size for unlabeled data = %d', | |||
| self.batch_size_nolabel) | |||
| self.logger.info(' Total optimization steps = %d', num_training_steps) | |||
| self.logger.info(' Total warmup steps = %d', num_warmup_steps) | |||
| self.logger.info('************************************') | |||
| def train(self, | |||
| train_label_iter, | |||
| train_nolabel_iter=None, | |||
| valid_label_iter=None, | |||
| valid_nolabel_iter=None): | |||
| # begin training | |||
| num_epochs = self.num_epochs - self.epoch | |||
| for epoch in range(num_epochs): | |||
| self.train_epoch( | |||
| train_label_iter=train_label_iter, | |||
| train_nolabel_iter=train_nolabel_iter, | |||
| valid_label_iter=valid_label_iter, | |||
| valid_nolabel_iter=valid_nolabel_iter) | |||
| def train_epoch(self, train_label_iter, train_nolabel_iter, | |||
| valid_label_iter, valid_nolabel_iter): | |||
| """ | |||
| Train an epoch. | |||
| """ | |||
| raise NotImplementedError | |||
| def evaluate(self, data_label_iter, data_nolabel_iter, need_save=True): | |||
| raise NotImplementedError | |||
| def infer(self, data_iter, num_batches=None): | |||
| raise NotImplementedError | |||
| def save(self, is_best=False): | |||
| """ save """ | |||
| train_state = { | |||
| 'epoch': self.epoch, | |||
| 'batch_num': self.batch_num, | |||
| 'best_valid_metric': self.best_valid_metric, | |||
| 'optimizer': self.optimizer.state_dict() | |||
| } | |||
| if self.lr_scheduler is not None: | |||
| train_state['lr_scheduler'] = self.lr_scheduler.state_dict() | |||
| # Save checkpoint | |||
| if self.save_checkpoint: | |||
| model_file = os.path.join(self.save_dir, | |||
| f'state_epoch_{self.epoch}.model') | |||
| torch.save(self.model.state_dict(), model_file) | |||
| self.logger.info(f"Saved model state to '{model_file}'") | |||
| train_file = os.path.join(self.save_dir, | |||
| f'state_epoch_{self.epoch}.train') | |||
| torch.save(train_state, train_file) | |||
| self.logger.info(f"Saved train state to '{train_file}'") | |||
| # Save current best model | |||
| if is_best: | |||
| best_model_file = os.path.join(self.save_dir, 'best.model') | |||
| torch.save(self.model.state_dict(), best_model_file) | |||
| best_train_file = os.path.join(self.save_dir, 'best.train') | |||
| torch.save(train_state, best_train_file) | |||
| self.logger.info( | |||
| f"Saved best model state to '{best_model_file}' with new best valid metric " | |||
| f'{self.valid_metric_name.upper()}={self.best_valid_metric:.3f}' | |||
| ) | |||
| def load(self): | |||
| """ load """ | |||
| def _load_model_state(): | |||
| model_state_dict = torch.load( | |||
| f'{self.func_model.init_checkpoint}.model', | |||
| map_location=lambda storage, loc: storage) | |||
| if 'module.' in list(model_state_dict.keys())[0]: | |||
| new_model_state_dict = OrderedDict() | |||
| for k, v in model_state_dict.items(): | |||
| assert k[:7] == 'module.' | |||
| new_model_state_dict[k[7:]] = v | |||
| model_state_dict = new_model_state_dict | |||
| new_model_state_dict = OrderedDict() | |||
| parameters = { | |||
| name: param | |||
| for name, param in self.func_model.named_parameters() | |||
| } | |||
| for name, param in model_state_dict.items(): | |||
| if name in parameters: | |||
| if param.shape != parameters[name].shape: | |||
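| # shape mismatch between checkpoint and model: build a randomly initialized | |||
| # tensor of the target shape and copy over the overlapping slice from the checkpoint | |||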
| assert hasattr(param, 'numpy') | |||
| arr = param.numpy() | |||
| z = np.random.normal( | |||
| scale=self.func_model.initializer_range, | |||
| size=parameters[name].shape).astype('float32') | |||
| if name == 'embedder.token_embedding.weight': | |||
| z[-param.shape[0]:] = arr | |||
| print( | |||
| f'part of parameter({name}) randomly normal-initialized' | |||
| ) | |||
| else: | |||
| if z.shape[0] < param.shape[0]: | |||
| z = arr[:z.shape[0]] | |||
| print(f'part of parameter({name}) is dropped') | |||
| else: | |||
| z[:param.shape[0]] = arr | |||
| print( | |||
| f'part of parameter({name}) randomly normal-initialized' | |||
| ) | |||
| dtype, device = param.dtype, param.device | |||
| z = torch.tensor(z, dtype=dtype, device=device) | |||
| new_model_state_dict[name] = z | |||
| else: | |||
| new_model_state_dict[name] = param | |||
| else: | |||
| print(f'parameter({name}) is dropped') | |||
| model_state_dict = new_model_state_dict | |||
| for name in parameters: | |||
| if name not in model_state_dict: | |||
| if parameters[name].requires_grad: | |||
| print(f'parameter({name}) randomly normal-initialized') | |||
| z = np.random.normal( | |||
| scale=self.func_model.initializer_range, | |||
| size=parameters[name].shape).astype('float32') | |||
| dtype, device = parameters[name].dtype, parameters[ | |||
| name].device | |||
| model_state_dict[name] = torch.tensor( | |||
| z, dtype=dtype, device=device) | |||
| else: | |||
| model_state_dict[name] = parameters[name] | |||
| self.func_model.load_state_dict(model_state_dict) | |||
| self.logger.info( | |||
| f"Loaded model state from '{self.func_model.init_checkpoint}.model'" | |||
| ) | |||
| def _load_train_state(): | |||
| train_file = f'{self.func_model.init_checkpoint}.train' | |||
| if os.path.exists(train_file): | |||
| train_state_dict = torch.load( | |||
| train_file, map_location=lambda storage, loc: storage) | |||
| self.epoch = train_state_dict['epoch'] | |||
| self.best_valid_metric = train_state_dict['best_valid_metric'] | |||
| if self.optimizer is not None and 'optimizer' in train_state_dict: | |||
| self.optimizer.load_state_dict( | |||
| train_state_dict['optimizer']) | |||
| if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict: | |||
| self.lr_scheduler.load_state_dict( | |||
| train_state_dict['lr_scheduler']) | |||
| self.logger.info( | |||
| f"Loaded train state from '{train_file}' with (epoch-{self.epoch} " | |||
| f'best_valid_metric={self.best_valid_metric:.3f})') | |||
| else: | |||
| self.logger.info('Loaded no train state') | |||
| if self.func_model.init_checkpoint is None: | |||
| self.logger.info('Loaded no model !!!') | |||
| return | |||
| _load_model_state() | |||
| _load_train_state() | |||
| class IntentTrainer(Trainer): | |||
| def __init__(self, model, to_tensor, config, reader=None): | |||
| super(IntentTrainer, self).__init__(model, to_tensor, config, reader) | |||
| self.example = config.Model.example | |||
| self.can_norm = config.Trainer.can_norm | |||
| def can_normalization(self, y_pred, y_true, ex_data_iter): | |||
| # compute ACC | |||
| acc_original = np.mean(y_pred.argmax(1) == y_true) | |||
| message = 'original acc: %s' % acc_original | |||
| # compute uncertainty | |||
| k = 3 | |||
| y_pred_topk = np.sort(y_pred, axis=1)[:, -k:] | |||
| y_pred_topk /= y_pred_topk.sum(axis=1, keepdims=True) | |||
| y_pred_uncertainty =\ | |||
| -(y_pred_topk * np.log(y_pred_topk)).sum(1) / np.log(k) | |||
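| # normalized top-k entropy in [0, 1]; larger values mean less confident predictions | |||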
| # choose threshold | |||
| # print(np.sort(y_pred_uncertainty)[-100:].tolist()) | |||
| threshold = 0.7 | |||
| y_pred_confident = y_pred[y_pred_uncertainty < threshold] | |||
| y_pred_unconfident = y_pred[y_pred_uncertainty >= threshold] | |||
| y_true_confident = y_true[y_pred_uncertainty < threshold] | |||
| y_true_unconfident = y_true[y_pred_uncertainty >= threshold] | |||
| # compute ACC again for high and low confidence sets | |||
| acc_confident = (y_pred_confident.argmax(1) == y_true_confident).mean() \ | |||
| if len(y_true_confident) else 0. | |||
| acc_unconfident = (y_pred_unconfident.argmax(1) == y_true_unconfident).mean() \ | |||
| if len(y_true_unconfident) else 0. | |||
| message += ' (%s) confident acc: %s' % (len(y_true_confident), | |||
| acc_confident) | |||
| message += ' (%s) unconfident acc: %s' % (len(y_true_unconfident), | |||
| acc_unconfident) | |||
| # get prior distribution from training set | |||
| prior = np.zeros(self.func_model.num_intent) | |||
| for _, (batch, batch_size) in ex_data_iter: | |||
| for intent_label in batch['intent_label']: | |||
| prior[intent_label] += 1. | |||
| prior /= prior.sum() | |||
| # revise each sample from the low confidence set, and compute new ACC | |||
| right, alpha, iters = 0, 1, 1 | |||
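| # CAN-style revision: append each unconfident prediction to the confident set, | |||
| # iteratively rescale columns toward the label prior, then re-read its label | |||
| # from the last row | |||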
| for i, y in enumerate(y_pred_unconfident): | |||
| Y = np.concatenate([y_pred_confident, y[None]], axis=0) | |||
| for j in range(iters): | |||
| Y = Y**alpha | |||
| Y /= Y.mean(axis=0, keepdims=True) | |||
| Y *= prior[None] | |||
| Y /= Y.sum(axis=1, keepdims=True) | |||
| y = Y[-1] | |||
| if y.argmax() == y_true_unconfident[i]: | |||
| right += 1 | |||
| # get final ACC | |||
| acc_final = \ | |||
| (acc_confident * len(y_pred_confident) + right) / \ | |||
| len(y_pred) | |||
| if len(y_pred_unconfident): | |||
| message += ' new unconfident acc: %s' % ( | |||
| right / len(y_pred_unconfident)) | |||
| else: | |||
| message += ' no unconfident predictions' | |||
| message += ' final acc: %s' % acc_final | |||
| return acc_original, acc_final, message | |||
| def train_epoch(self, train_label_iter, train_nolabel_iter, | |||
| valid_label_iter, valid_nolabel_iter): | |||
| """ | |||
| Train an epoch. | |||
| """ | |||
| times = [] | |||
| self.epoch += 1 | |||
| self.batch_metrics_tracker_label.clear() | |||
| self.token_metrics_tracker_label.clear() | |||
| self.batch_metrics_tracker_nolabel.clear() | |||
| self.token_metrics_tracker_nolabel.clear() | |||
| num_label_batches = len(train_label_iter) | |||
| num_nolabel_batches = len( | |||
| train_nolabel_iter) if train_nolabel_iter is not None else 0 | |||
| num_batches = max(num_label_batches, num_nolabel_batches) | |||
| train_label_iter_loop = iter(train_label_iter) | |||
| train_nolabel_iter_loop = iter( | |||
| train_nolabel_iter) if train_nolabel_iter is not None else None | |||
| report_for_unlabeled_data = True if train_nolabel_iter is not None else False | |||
| for batch_id in range(1, num_batches + 1): | |||
| # Do a training iteration | |||
| start_time = time.time() | |||
| batch_list, batch_size_list, with_label_list, loss_list, metrics_list = [], [], [], [], [] | |||
| data_file_list = [] | |||
| # collect batch for labeled data | |||
| try: | |||
| data_file_label, ( | |||
| batch_label, | |||
| batch_size_label) = next(train_label_iter_loop) | |||
| except StopIteration: | |||
| train_label_iter_loop = iter(train_label_iter) | |||
| data_file_label, ( | |||
| batch_label, | |||
| batch_size_label) = next(train_label_iter_loop) | |||
| batch_list.append(batch_label) | |||
| batch_size_list.append(batch_size_label) | |||
| with_label_list.append(True) | |||
| data_file_list.append(data_file_label) | |||
| # collect batch for unlabeled data | |||
| if train_nolabel_iter is not None: | |||
| try: | |||
| data_file_nolabel, ( | |||
| batch_nolabel, | |||
| batch_size_nolabel) = next(train_nolabel_iter_loop) | |||
| except StopIteration: | |||
| train_nolabel_iter_loop = iter(train_nolabel_iter) | |||
| data_file_nolabel, ( | |||
| batch_nolabel, | |||
| batch_size_nolabel) = next(train_nolabel_iter_loop) | |||
| batch_list.append(batch_nolabel) | |||
| batch_size_list.append(batch_size_nolabel) | |||
| with_label_list.append(False) | |||
| data_file_list.append(data_file_nolabel) | |||
| # forward labeled batch and unlabeled batch and collect outputs, respectively | |||
| for (batch, batch_size, with_label, data_file) in \ | |||
| zip(batch_list, batch_size_list, with_label_list, data_file_list): | |||
| batch = type(batch)( | |||
| map(lambda kv: (kv[0], self.to_tensor(kv[1])), | |||
| batch.items())) | |||
| if self.example and with_label: | |||
| current_dataset = train_label_iter.data_file_to_dataset[ | |||
| data_file] | |||
| example_batch = self.reader.retrieve_examples( | |||
| dataset=current_dataset, | |||
| labels=batch['intent_label'], | |||
| inds=batch['ids'], | |||
| task='intent') | |||
| example_batch = type(example_batch)( | |||
| map(lambda kv: (kv[0], self.to_tensor(kv[1])), | |||
| example_batch.items())) | |||
| for k, v in example_batch.items(): | |||
| batch[k] = v | |||
| batch['epoch'] = self.epoch | |||
| batch['num_steps'] = self.batch_num | |||
| metrics = self.model( | |||
| batch, | |||
| is_training=True, | |||
| with_label=with_label, | |||
| data_file=data_file) | |||
| loss, metrics = self.balance_metrics( | |||
| metrics=metrics, batch_size=batch_size) | |||
| loss_list.append(loss) | |||
| metrics_list.append(metrics) | |||
| # combine loss for labeled data and unlabeled data | |||
| # TODO change the computation of combined loss of labeled batch and unlabeled batch | |||
| loss = loss_list[0] if len( | |||
| loss_list) == 1 else loss_list[0] + loss_list[1] | |||
| # optimization procedure | |||
| self.func_model._optimize( | |||
| loss, optimizer=self.optimizer, lr_scheduler=self.lr_scheduler) | |||
| elapsed = time.time() - start_time | |||
| times.append(elapsed) | |||
| self.batch_num += 1 | |||
| # track metrics and log temporary message | |||
| for (batch_size, metrics, | |||
| with_label) in zip(batch_size_list, metrics_list, | |||
| with_label_list): | |||
| self.track_and_log_message( | |||
| metrics=metrics, | |||
| batch_id=batch_id, | |||
| batch_size=batch_size, | |||
| num_batches=num_batches, | |||
| times=times, | |||
| with_label=with_label) | |||
| # evaluate | |||
| if self.valid_steps > 0 and valid_label_iter is not None and valid_nolabel_iter is not None \ | |||
| and batch_id % self.valid_steps == 0: | |||
| self.evaluate( | |||
| data_label_iter=valid_label_iter, | |||
| data_nolabel_iter=valid_nolabel_iter) | |||
| # compute accuracy for valid dataset | |||
| accuracy = self.infer( | |||
| data_iter=valid_label_iter, ex_data_iter=train_label_iter) | |||
| # report summary message and save checkpoints | |||
| self.save_and_log_message( | |||
| report_for_unlabeled_data, cur_valid_metric=-accuracy) | |||
| def forward(self, batch): | |||
| pred = [] | |||
| with torch.no_grad(): | |||
| batch = type(batch)( | |||
| map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items())) | |||
| result = self.model.infer(inputs=batch) | |||
| result = { | |||
| name: result[name].cpu().detach().numpy() | |||
| for name in result | |||
| } | |||
| intent_probs = result['intent_probs'] | |||
| if self.can_norm: | |||
| pred += [intent_probs] | |||
| else: | |||
| pred += np.argmax(intent_probs, axis=1).tolist() | |||
| return pred | |||
| def infer(self, data_iter, num_batches=None, ex_data_iter=None): | |||
| """ | |||
| Inference interface. | |||
| """ | |||
| self.logger.info('Generation starts ...') | |||
| infer_save_file = os.path.join(self.save_dir, | |||
| f'infer_{self.epoch}.result.json') | |||
| # Inference | |||
| batch_cnt = 0 | |||
| pred, true = [], [] | |||
| outputs, labels = [], [] | |||
| begin_time = time.time() | |||
| with torch.no_grad(): | |||
| if self.example: | |||
| for _, (batch, batch_size) in tqdm( | |||
| ex_data_iter, desc='Building train memory.'): | |||
| batch = type(batch)( | |||
| map(lambda kv: (kv[0], self.to_tensor(kv[1])), | |||
| batch.items())) | |||
| result = self.model.infer(inputs=batch) | |||
| result = { | |||
| name: result[name].cpu().detach().numpy() | |||
| for name in result | |||
| } | |||
| outputs.append(torch.from_numpy(result['features'])) | |||
| labels += batch['intent_label'].tolist() | |||
| mem = torch.cat(outputs, dim=0) | |||
| mem = mem.cuda() if self.func_model.use_gpu else mem | |||
| labels = torch.LongTensor(labels).unsqueeze(0) | |||
| labels = labels.cuda() if self.func_model.use_gpu else labels | |||
| self.logger.info(f'Memory size: {mem.size()}') | |||
| for _, (batch, batch_size) in tqdm(data_iter, total=num_batches): | |||
| batch = type(batch)( | |||
| map(lambda kv: (kv[0], self.to_tensor(kv[1])), | |||
| batch.items())) | |||
| result = self.model.infer(inputs=batch) | |||
| result = { | |||
| name: result[name].cpu().detach().numpy() | |||
| for name in result | |||
| } | |||
| if self.example: | |||
| features = torch.from_numpy(result['features']) | |||
| features = features.cuda( | |||
| ) if self.func_model.use_gpu else features | |||
| probs = torch.softmax(features.mm(mem.t()), dim=-1) | |||
| intent_probs = torch.zeros( | |||
| probs.size(0), self.func_model.num_intent) | |||
| intent_probs = intent_probs.cuda( | |||
| ) if self.func_model.use_gpu else intent_probs | |||
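| # each memory example votes for its gold intent with its softmax similarity | |||
| # score; the votes are accumulated per intent label via scatter_add | |||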
| intent_probs = intent_probs.scatter_add( | |||
| -1, labels.repeat(probs.size(0), 1), probs) | |||
| intent_probs = intent_probs.cpu().detach().numpy() | |||
| else: | |||
| intent_probs = result['intent_probs'] | |||
| if self.can_norm: | |||
| pred += [intent_probs] | |||
| true += batch['intent_label'].cpu().detach().tolist() | |||
| else: | |||
| pred += np.argmax(intent_probs, axis=1).tolist() | |||
| true += batch['intent_label'].cpu().detach().tolist() | |||
| batch_cnt += 1 | |||
| if batch_cnt == num_batches: | |||
| break | |||
| if self.can_norm: | |||
| true = np.array(true) | |||
| pred = np.concatenate(pred, axis=0) | |||
| acc_original, acc_final, message = self.can_normalization( | |||
| y_pred=pred, y_true=true, ex_data_iter=ex_data_iter) | |||
| accuracy = max(acc_original, acc_final) | |||
| infer_results = { | |||
| 'accuracy': accuracy, | |||
| 'pred_labels': pred.tolist(), | |||
| 'message': message | |||
| } | |||
| metrics_message = f'Accuracy: {accuracy} {message}' | |||
| else: | |||
| accuracy = sum(p == t for p, t in zip(pred, true)) / len(pred) | |||
| infer_results = {'accuracy': accuracy, 'pred_labels': pred} | |||
| metrics_message = f'Accuracy: {accuracy}' | |||
| self.logger.info(f'Saved inference results to {infer_save_file}') | |||
| with open(infer_save_file, 'w') as fp: | |||
| json.dump(infer_results, fp, indent=2) | |||
| message_prefix = f'[Infer][{self.epoch}]' | |||
| time_cost = f'TIME-{time.time() - begin_time:.3f}' | |||
| message = ' '.join([message_prefix, metrics_message, time_cost]) | |||
| self.logger.info(message) | |||
| return accuracy | |||
| def track_and_log_message(self, metrics, batch_id, batch_size, num_batches, | |||
| times, with_label): | |||
| # track metrics | |||
| batch_metrics_tracker = self.batch_metrics_tracker_label if with_label else self.batch_metrics_tracker_nolabel | |||
| token_metrics_tracker = self.token_metrics_tracker_label if with_label else self.token_metrics_tracker_nolabel | |||
| metrics = { | |||
| k: v.cpu().detach().numpy() if isinstance(v, torch.Tensor) else v | |||
| for k, v in metrics.items() | |||
| } | |||
| mlm_num = metrics.pop('mlm_num', 0) | |||
| batch_metrics = {k: v for k, v in metrics.items() if 'token' not in k} | |||
| token_metrics = {k: v for k, v in metrics.items() if 'token' in k} | |||
| batch_metrics_tracker.update(batch_metrics, batch_size) | |||
| token_metrics_tracker.update(token_metrics, mlm_num) | |||
| # log message | |||
| if self.log_steps > 0 and batch_id % self.log_steps == 0: | |||
| batch_metrics_message = batch_metrics_tracker.value() | |||
| token_metrics_message = token_metrics_tracker.value() | |||
| label_prefix = 'Labeled' if with_label else 'Unlabeled' | |||
| message_prefix = f'[Train][{self.epoch}][{batch_id}/{num_batches}][{label_prefix}]' | |||
| avg_time = f'AVG_Time-{sum(times[-self.log_steps:]) / self.log_steps:.3f}' | |||
| message = ' '.join([ | |||
| message_prefix, batch_metrics_message, token_metrics_message, | |||
| avg_time | |||
| ]) | |||
| self.logger.info(message) | |||
| def save_and_log_message(self, | |||
| report_for_unlabeled_data, | |||
| cur_valid_metric=None): | |||
| # report message | |||
| batch_metrics_message = self.batch_metrics_tracker_label.summary() | |||
| token_metrics_message = self.token_metrics_tracker_label.summary() | |||
| message_prefix = f'[Valid][{self.epoch}][Labeled]' | |||
| message = ' '.join( | |||
| [message_prefix, batch_metrics_message, token_metrics_message]) | |||
| self.logger.info(message) | |||
| if report_for_unlabeled_data: | |||
| batch_metrics_message = self.batch_metrics_tracker_nolabel.summary( | |||
| ) | |||
| token_metrics_message = self.token_metrics_tracker_nolabel.summary( | |||
| ) | |||
| message_prefix = f'[Valid][{self.epoch}][Unlabeled]' | |||
| message = ' '.join( | |||
| [message_prefix, batch_metrics_message, token_metrics_message]) | |||
| self.logger.info(message) | |||
| # save checkpoints | |||
| assert cur_valid_metric is not None | |||
| if self.is_decreased_valid_metric: | |||
| is_best = cur_valid_metric < self.best_valid_metric | |||
| else: | |||
| is_best = cur_valid_metric > self.best_valid_metric | |||
| if is_best: | |||
| self.best_valid_metric = cur_valid_metric | |||
| self.save(is_best) | |||
| def balance_metrics(self, metrics, batch_size): | |||
| if self.gpu > 1: | |||
| for metric in metrics: | |||
| if metric is not None: | |||
| assert len(metric) == self.gpu | |||
| intent_loss, mlm, token_mlm, mlm_num, kl, con = metrics | |||
| metrics = {} | |||
| intent_loss = torch.mean(intent_loss) | |||
| metrics['intent_loss'] = intent_loss | |||
| loss = intent_loss | |||
| if mlm is not None: | |||
| mlm_num = torch.sum(mlm_num) | |||
| token_mlm = torch.sum(mlm) * (batch_size / self.gpu) / mlm_num | |||
| mlm = torch.mean(mlm) | |||
| metrics['mlm_num'] = mlm_num | |||
| metrics['token_mlm'] = token_mlm | |||
| metrics['mlm'] = mlm | |||
| loss = loss + (token_mlm if self.func_model.token_loss else | |||
| mlm) * self.func_model.mlm_ratio | |||
| if kl is not None: | |||
| kl = torch.mean(kl) | |||
| metrics['kl'] = kl | |||
| loss = loss + kl * self.func_model.kl_ratio | |||
| if con is not None: | |||
| con = torch.mean(con) | |||
| metrics['con'] = con | |||
| loss = loss + con | |||
| metrics['loss'] = loss | |||
| assert 'loss' in metrics | |||
| return metrics['loss'], metrics | |||
| def load(self): | |||
| """ load """ | |||
| def _load_model_state(): | |||
| model_state_dict = torch.load( | |||
| f'{self.func_model.init_checkpoint}', | |||
| map_location=lambda storage, loc: storage) | |||
| if 'module.' in list(model_state_dict.keys())[0]: | |||
| new_model_state_dict = OrderedDict() | |||
| for k, v in model_state_dict.items(): | |||
| assert k[:7] == 'module.' | |||
| new_model_state_dict[k[7:]] = v | |||
| model_state_dict = new_model_state_dict | |||
| new_model_state_dict = OrderedDict() | |||
| parameters = { | |||
| name: param | |||
| for name, param in self.func_model.named_parameters() | |||
| } | |||
| for name, param in model_state_dict.items(): | |||
| if name in parameters: | |||
| if param.shape != parameters[name].shape: | |||
| assert hasattr(param, 'numpy') | |||
| arr = param.numpy() | |||
| z = np.random.normal( | |||
| scale=self.func_model.initializer_range, | |||
| size=parameters[name].shape).astype('float32') | |||
| if name == 'embedder.token_embedding.weight': | |||
| z[-param.shape[0]:] = arr | |||
| print( | |||
| f'part of parameter({name}) randomly normal-initialized' | |||
| ) | |||
| else: | |||
| if z.shape[0] < param.shape[0]: | |||
| z = arr[:z.shape[0]] | |||
| print(f'part of parameter({name}) is dropped') | |||
| else: | |||
| z[:param.shape[0]] = arr | |||
| print( | |||
| f'part of parameter({name}) randomly normal-initialized' | |||
| ) | |||
| dtype, device = param.dtype, param.device | |||
| z = torch.tensor(z, dtype=dtype, device=device) | |||
| new_model_state_dict[name] = z | |||
| else: | |||
| new_model_state_dict[name] = param | |||
| else: | |||
| print(f'parameter({name}) is dropped') | |||
| model_state_dict = new_model_state_dict | |||
| for name in parameters: | |||
| if name not in model_state_dict: | |||
| if parameters[name].requires_grad: | |||
| print(f'parameter({name}) randomly normal-initialized') | |||
| z = np.random.normal( | |||
| scale=self.func_model.initializer_range, | |||
| size=parameters[name].shape).astype('float32') | |||
| dtype, device = parameters[name].dtype, parameters[ | |||
| name].device | |||
| model_state_dict[name] = torch.tensor( | |||
| z, dtype=dtype, device=device) | |||
| else: | |||
| model_state_dict[name] = parameters[name] | |||
| self.func_model.load_state_dict(model_state_dict) | |||
| self.logger.info( | |||
| f"Loaded model state from '{self.func_model.init_checkpoint}.model'" | |||
| ) | |||
| def _load_train_state(): | |||
| train_file = f'{self.func_model.init_checkpoint}.train' | |||
| if os.path.exists(train_file): | |||
| train_state_dict = torch.load( | |||
| train_file, map_location=lambda storage, loc: storage) | |||
| self.epoch = train_state_dict['epoch'] | |||
| self.best_valid_metric = train_state_dict['best_valid_metric'] | |||
| if self.optimizer is not None and 'optimizer' in train_state_dict: | |||
| self.optimizer.load_state_dict( | |||
| train_state_dict['optimizer']) | |||
| if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict: | |||
| self.lr_scheduler.load_state_dict( | |||
| train_state_dict['lr_scheduler']) | |||
| self.logger.info( | |||
| f"Loaded train state from '{train_file}' with (epoch-{self.epoch} " | |||
| f'best_valid_metric={self.best_valid_metric:.3f})') | |||
| else: | |||
| self.logger.info('Loaded no train state') | |||
| if self.func_model.init_checkpoint is None: | |||
| self.logger.info('Loaded no model !!!') | |||
| return | |||
| if self.do_train: | |||
| _load_model_state() | |||
| return | |||
| if self.do_infer: | |||
| _load_model_state() | |||
| _load_train_state() | |||
| @@ -34,6 +34,8 @@ class Tasks(object): | |||
| # nlp tasks | |||
| word_segmentation = 'word-segmentation' | |||
| nli = 'nli' | |||
| sentiment_classification = 'sentiment-classification' | |||
| sentiment_analysis = 'sentiment-analysis' | |||
| sentence_similarity = 'sentence-similarity' | |||
| text_classification = 'text-classification' | |||
| @@ -43,6 +45,8 @@ class Tasks(object): | |||
| token_classification = 'token-classification' | |||
| conversational = 'conversational' | |||
| text_generation = 'text-generation' | |||
| dialog_modeling = 'dialog-modeling' | |||
| dialog_intent_prediction = 'dialog-intent-prediction' | |||
| table_question_answering = 'table-question-answering' | |||
| feature_extraction = 'feature-extraction' | |||
| fill_mask = 'fill-mask' | |||
| @@ -0,0 +1,66 @@ | |||
| """ | |||
| Parse argument. | |||
| """ | |||
| import argparse | |||
| import json | |||
| def str2bool(v): | |||
| if v.lower() in ('yes', 'true', 't', 'y', '1'): | |||
| return True | |||
| elif v.lower() in ('no', 'false', 'f', 'n', '0'): | |||
| return False | |||
| else: | |||
| raise argparse.ArgumentTypeError('Unsupported value encountered.') | |||
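| # Illustrative usage with argparse (hypothetical flag name): | |||
| #   parser.add_argument('--use_gpu', type=str2bool, default=True) | |||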
| class HParams(dict): | |||
| """ Hyper-parameters class | |||
| Store hyper-parameters in training / infer / ... scripts. | |||
| """ | |||
| def __getattr__(self, name): | |||
| if name in self.keys(): | |||
| return self[name] | |||
| for v in self.values(): | |||
| if isinstance(v, HParams): | |||
| if name in v: | |||
| return v[name] | |||
| raise AttributeError(f"'HParams' object has no attribute '{name}'") | |||
| def __setattr__(self, name, value): | |||
| self[name] = value | |||
| def save(self, filename): | |||
| with open(filename, 'w', encoding='utf-8') as fp: | |||
| json.dump(self, fp, ensure_ascii=False, indent=4, sort_keys=False) | |||
| def load(self, filename): | |||
| with open(filename, 'r', encoding='utf-8') as fp: | |||
| params_dict = json.load(fp) | |||
| for k, v in params_dict.items(): | |||
| if isinstance(v, dict): | |||
| self[k].update(HParams(v)) | |||
| else: | |||
| self[k] = v | |||
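| # Illustrative usage (hypothetical values, not part of the shipped configs): | |||
| #   hparams = HParams(Trainer=HParams(num_epochs=10), Model=HParams(lr=1e-3)) | |||
| #   hparams.num_epochs   # nested attribute lookup -> 10 | |||
| #   hparams.save('config.json'); hparams.load('config.json') | |||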
| def parse_args(parser): | |||
| """ Parse hyper-parameters from cmdline. """ | |||
| parsed = parser.parse_args() | |||
| args = HParams() | |||
| optional_args = parser._action_groups[1] | |||
| for action in optional_args._group_actions[1:]: | |||
| arg_name = action.dest | |||
| args[arg_name] = getattr(parsed, arg_name) | |||
| for group in parser._action_groups[2:]: | |||
| group_args = HParams() | |||
| for action in group._group_actions: | |||
| arg_name = action.dest | |||
| group_args[arg_name] = getattr(parsed, arg_name) | |||
| if len(group_args) > 0: | |||
| args[group.title] = group_args | |||
| return args | |||
| @@ -0,0 +1,52 @@ | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from torch.nn.modules.loss import _Loss | |||
| def compute_kl_loss(p, q, filter_scores=None): | |||
| p_loss = F.kl_div( | |||
| F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none') | |||
| q_loss = F.kl_div( | |||
| F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none') | |||
| # choose between "sum" and "mean" reduction depending on your task | |||
| p_loss = p_loss.sum(dim=-1) | |||
| q_loss = q_loss.sum(dim=-1) | |||
| # optional per-sample weights acting as a filter mechanism | |||
| if filter_scores is not None: | |||
| p_loss = filter_scores * p_loss | |||
| q_loss = filter_scores * q_loss | |||
| p_loss = p_loss.mean() | |||
| q_loss = q_loss.mean() | |||
| loss = (p_loss + q_loss) / 2 | |||
| return loss | |||
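| # Illustrative usage (assumed shapes): p and q are two sets of logits for the | |||
| # same batch, e.g. from two forward passes with different dropout masks: | |||
| #   logits_p = model(batch)                    # (batch_size, num_classes) | |||
| #   logits_q = model(batch) | |||
| #   kl = compute_kl_loss(logits_p, logits_q)   # scalar symmetric KL term | |||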
| class CatKLLoss(_Loss): | |||
| """ | |||
| CatKLLoss | |||
| """ | |||
| def __init__(self, reduction='mean'): | |||
| super(CatKLLoss, self).__init__() | |||
| assert reduction in ['none', 'sum', 'mean'] | |||
| self.reduction = reduction | |||
| def forward(self, log_qy, log_py): | |||
| """ | |||
| KL(q(y) || p(y)) = E_q[log q(y) - log p(y)] = sum_y q(y) * (log q(y) - log p(y)) | |||
| log_qy: (batch_size, latent_size) | |||
| log_py: (batch_size, latent_size) | |||
| """ | |||
| qy = torch.exp(log_qy) | |||
| kl = torch.sum(qy * (log_qy - log_py), dim=1) | |||
| if self.reduction == 'mean': | |||
| kl = kl.mean() | |||
| elif self.reduction == 'sum': | |||
| kl = kl.sum() | |||
| return kl | |||
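| # Illustrative usage (assumed shapes): both inputs are log-probabilities of | |||
| # shape (batch_size, latent_size), e.g. from F.log_softmax(logits, dim=-1): | |||
| #   kl = CatKLLoss(reduction='mean')(log_qy, log_py) | |||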
| @@ -0,0 +1,313 @@ | |||
| import os | |||
| import random | |||
| import sqlite3 | |||
| import json | |||
| from .ontology import all_domains, db_domains | |||
| class MultiWozDB(object): | |||
| def __init__(self, db_dir, db_paths): | |||
| self.dbs = {} | |||
| self.sql_dbs = {} | |||
| for domain in all_domains: | |||
| with open(os.path.join(db_dir, db_paths[domain]), 'r') as f: | |||
| self.dbs[domain] = json.loads(f.read().lower()) | |||
| def oneHotVector(self, domain, num): | |||
| """Return number of available entities for particular domain.""" | |||
| vector = [0, 0, 0, 0] | |||
| if num == '': | |||
| return vector | |||
| if domain != 'train': | |||
| if num == 0: | |||
| vector = [1, 0, 0, 0] | |||
| elif num == 1: | |||
| vector = [0, 1, 0, 0] | |||
| elif num <= 3: | |||
| vector = [0, 0, 1, 0] | |||
| else: | |||
| vector = [0, 0, 0, 1] | |||
| else: | |||
| if num == 0: | |||
| vector = [1, 0, 0, 0] | |||
| elif num <= 5: | |||
| vector = [0, 1, 0, 0] | |||
| elif num <= 10: | |||
| vector = [0, 0, 1, 0] | |||
| else: | |||
| vector = [0, 0, 0, 1] | |||
| return vector | |||
| def addBookingPointer(self, turn_da): | |||
| """Add information about availability of the booking option.""" | |||
| # Booking pointer | |||
| # Do not consider booking two things in a single turn. | |||
| vector = [0, 0] | |||
| if turn_da.get('booking-nobook'): | |||
| vector = [1, 0] | |||
| if turn_da.get('booking-book') or turn_da.get('train-offerbooked'): | |||
| vector = [0, 1] | |||
| return vector | |||
| def addDBPointer(self, domain, match_num, return_num=False): | |||
| """Create database pointer for all related domains.""" | |||
| # if turn_domains is None: | |||
| # turn_domains = db_domains | |||
| if domain in db_domains: | |||
| vector = self.oneHotVector(domain, match_num) | |||
| else: | |||
| vector = [0, 0, 0, 0] | |||
| return vector | |||
| def addDBIndicator(self, domain, match_num, return_num=False): | |||
| """Create database indicator for all related domains.""" | |||
| # if turn_domains is None: | |||
| # turn_domains = db_domains | |||
| if domain in db_domains: | |||
| vector = self.oneHotVector(domain, match_num) | |||
| else: | |||
| vector = [0, 0, 0, 0] | |||
| # '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]' | |||
| if vector == [0, 0, 0, 0]: | |||
| indicator = '[db_nores]' | |||
| else: | |||
| indicator = '[db_%s]' % vector.index(1) | |||
| return indicator | |||
| def get_match_num(self, constraints, return_entry=False): | |||
| """Create database pointer for all related domains.""" | |||
| match = {'general': ''} | |||
| entry = {} | |||
| # if turn_domains is None: | |||
| # turn_domains = db_domains | |||
| for domain in all_domains: | |||
| match[domain] = '' | |||
| if domain in db_domains and constraints.get(domain): | |||
| matched_ents = self.queryJsons(domain, constraints[domain]) | |||
| match[domain] = len(matched_ents) | |||
| if return_entry: | |||
| entry[domain] = matched_ents | |||
| if return_entry: | |||
| return entry | |||
| return match | |||
| def pointerBack(self, vector, domain): | |||
| # multi domain implementation | |||
| # domnum = cfg.domain_num | |||
| if domain.endswith(']'): | |||
| domain = domain[1:-1] | |||
| if domain != 'train': | |||
| nummap = {0: '0', 1: '1', 2: '2-3', 3: '>3'} | |||
| else: | |||
| nummap = {0: '0', 1: '1-5', 2: '6-10', 3: '>10'} | |||
| if vector[:4] == [0, 0, 0, 0]: | |||
| report = '' | |||
| else: | |||
| num = vector.index(1) | |||
| report = domain + ': ' + nummap[num] + '; ' | |||
| if vector[-2] == 0 and vector[-1] == 1: | |||
| report += 'booking: ok' | |||
| if vector[-2] == 1 and vector[-1] == 0: | |||
| report += 'booking: unable' | |||
| return report | |||
| def queryJsons(self, | |||
| domain, | |||
| constraints, | |||
| exactly_match=True, | |||
| return_name=False): | |||
| """Returns the list of entities for a given domain | |||
| based on the annotation of the belief state | |||
| constraints: dict e.g. {'pricerange': 'cheap', 'area': 'west'} | |||
| """ | |||
| # query the db | |||
| if domain == 'taxi': | |||
| return [{ | |||
| 'taxi_colors': | |||
| random.choice(self.dbs[domain]['taxi_colors']), | |||
| 'taxi_types': | |||
| random.choice(self.dbs[domain]['taxi_types']), | |||
| 'taxi_phone': [random.randint(1, 9) for _ in range(10)] | |||
| }] | |||
| if domain == 'police': | |||
| return self.dbs['police'] | |||
| if domain == 'hospital': | |||
| if constraints.get('department'): | |||
| for entry in self.dbs['hospital']: | |||
| if entry.get('department') == constraints.get( | |||
| 'department'): | |||
| return [entry] | |||
| else: | |||
| return [] | |||
| valid_cons = False | |||
| for v in constraints.values(): | |||
| if v not in ['not mentioned', '']: | |||
| valid_cons = True | |||
| if not valid_cons: | |||
| return [] | |||
| match_result = [] | |||
| if 'name' in constraints: | |||
| for db_ent in self.dbs[domain]: | |||
| if 'name' in db_ent: | |||
| cons = constraints['name'] | |||
| dbn = db_ent['name'] | |||
| if cons == dbn: | |||
| db_ent = db_ent if not return_name else db_ent['name'] | |||
| match_result.append(db_ent) | |||
| return match_result | |||
| for db_ent in self.dbs[domain]: | |||
| match = True | |||
| for s, v in constraints.items(): | |||
| if s == 'name': | |||
| continue | |||
| if s in ['people', 'stay'] or (domain == 'hotel' and s == 'day') or \ | |||
| (domain == 'restaurant' and s in ['day', 'time']): | |||
| # These inform slots belong to "book info", which does not exist in the DB; | |||
| # "book" comes from the user goal, not from the DB | |||
| continue | |||
| skip_case = { | |||
| "don't care": 1, | |||
| "do n't care": 1, | |||
| 'dont care': 1, | |||
| 'not mentioned': 1, | |||
| 'dontcare': 1, | |||
| '': 1 | |||
| } | |||
| if skip_case.get(v): | |||
| continue | |||
| if s not in db_ent: | |||
| # logging.warning('Searching warning: slot %s not in %s db'%(s, domain)) | |||
| match = False | |||
| break | |||
| # v = 'guesthouse' if v == 'guest house' else v | |||
| # v = 'swimmingpool' if v == 'swimming pool' else v | |||
| v = 'yes' if v == 'free' else v | |||
| if s in ['arrive', 'leave']: | |||
| try: | |||
| h, m = v.split( | |||
| ':' | |||
| ) # raise error if time value is not xx:xx format | |||
| v = int(h) * 60 + int(m) | |||
| except Exception: | |||
| match = False | |||
| break | |||
| time = int(db_ent[s].split(':')[0]) * 60 + int( | |||
| db_ent[s].split(':')[1]) | |||
| if s == 'arrive' and v > time: | |||
| match = False | |||
| if s == 'leave' and v < time: | |||
| match = False | |||
| else: | |||
| if exactly_match and v != db_ent[s]: | |||
| match = False | |||
| break | |||
| elif v not in db_ent[s]: | |||
| match = False | |||
| break | |||
| if match: | |||
| match_result.append(db_ent) | |||
| if not return_name: | |||
| return match_result | |||
| else: | |||
| if domain == 'train': | |||
| match_result = [e['id'] for e in match_result] | |||
| else: | |||
| match_result = [e['name'] for e in match_result] | |||
| return match_result | |||
| def querySQL(self, domain, constraints): | |||
| if not self.sql_dbs: | |||
| for dom in db_domains: | |||
| db = 'db/{}-dbase.db'.format(dom) | |||
| conn = sqlite3.connect(db) | |||
| c = conn.cursor() | |||
| self.sql_dbs[dom] = c | |||
| sql_query = 'select * from {}'.format(domain) | |||
| flag = True | |||
| for key, val in constraints.items(): | |||
| if val == '' \ | |||
| or val == 'dontcare' \ | |||
| or val == 'not mentioned' \ | |||
| or val == "don't care" \ | |||
| or val == 'dont care' \ | |||
| or val == "do n't care": | |||
| pass | |||
| else: | |||
| if flag: | |||
| sql_query += ' where ' | |||
| val2 = val.replace("'", "''") | |||
| # val2 = normalize(val2) | |||
| if key == 'leaveAt': | |||
| sql_query += r' ' + key + ' > ' + r"'" + val2 + r"'" | |||
| elif key == 'arriveBy': | |||
| sql_query += r' ' + key + ' < ' + r"'" + val2 + r"'" | |||
| else: | |||
| sql_query += r' ' + key + '=' + r"'" + val2 + r"'" | |||
| flag = False | |||
| else: | |||
| val2 = val.replace("'", "''") | |||
| # val2 = normalize(val2) | |||
| if key == 'leaveAt': | |||
| sql_query += r' and ' + key + ' > ' + r"'" + val2 + r"'" | |||
| elif key == 'arriveBy': | |||
| sql_query += r' and ' + key + ' < ' + r"'" + val2 + r"'" | |||
| else: | |||
| sql_query += r' and ' + key + '=' + r"'" + val2 + r"'" | |||
| try: # "select * from attraction where name = 'queens college'" | |||
| print(sql_query) | |||
| return self.sql_dbs[domain].execute(sql_query).fetchall() | |||
| except Exception: | |||
| return [] # TODO test it | |||
| if __name__ == '__main__': | |||
| dbPATHs = { | |||
| 'attraction': 'db/attraction_db_processed.json', | |||
| 'hospital': 'db/hospital_db_processed.json', | |||
| 'hotel': 'db/hotel_db_processed.json', | |||
| 'police': 'db/police_db_processed.json', | |||
| 'restaurant': 'db/restaurant_db_processed.json', | |||
| 'taxi': 'db/taxi_db_processed.json', | |||
| 'train': 'db/train_db_processed.json', | |||
| } | |||
| db = MultiWozDB('.', dbPATHs)  # db_dir assumed to be the MultiWOZ data root | |||
| while True: | |||
| constraints = {} | |||
| inp = input( | |||
| 'input belief state in format: domain-slot1=value1;slot2=value2...\n' | |||
| ) | |||
| domain, cons = inp.split('-') | |||
| for sv in cons.split(';'): | |||
| s, v = sv.split('=') | |||
| constraints[s] = v | |||
| # res = db.querySQL(domain, constraints) | |||
| res = db.queryJsons(domain, constraints, return_name=True) | |||
| report = [] | |||
| reidx = { | |||
| 'hotel': 8, | |||
| 'restaurant': 6, | |||
| 'attraction': 5, | |||
| 'train': 1, | |||
| } | |||
| print(constraints) | |||
| print(res) | |||
| print('count:', len(res), '\nnames:', report) | |||
| @@ -0,0 +1,204 @@ | |||
| all_domains = [ | |||
| 'restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital' | |||
| ] | |||
| all_domains_with_bracket = ['[{}]'.format(item) for item in all_domains] | |||
| db_domains = ['restaurant', 'hotel', 'attraction', 'train'] | |||
| placeholder_tokens = [ | |||
| '<go_r>', '<go_b>', '<go_a>', '<go_d>', '<eos_u>', '<eos_r>', '<eos_b>', | |||
| '<eos_a>', '<eos_d>', '<eos_q>', '<sos_u>', '<sos_r>', '<sos_b>', | |||
| '<sos_a>', '<sos_d>', '<sos_q>' | |||
| ] | |||
| normlize_slot_names = { | |||
| 'car type': 'car', | |||
| 'entrance fee': 'price', | |||
| 'duration': 'time', | |||
| 'leaveat': 'leave', | |||
| 'arriveby': 'arrive', | |||
| 'trainid': 'id' | |||
| } | |||
| requestable_slots = { | |||
| 'taxi': ['car', 'phone'], | |||
| 'police': ['postcode', 'address', 'phone'], | |||
| 'hospital': ['address', 'phone', 'postcode'], | |||
| 'hotel': [ | |||
| 'address', 'postcode', 'internet', 'phone', 'parking', 'type', | |||
| 'pricerange', 'stars', 'area', 'reference' | |||
| ], | |||
| 'attraction': | |||
| ['price', 'type', 'address', 'postcode', 'phone', 'area', 'reference'], | |||
| 'train': ['time', 'leave', 'price', 'arrive', 'id', 'reference'], | |||
| 'restaurant': [ | |||
| 'phone', 'postcode', 'address', 'pricerange', 'food', 'area', | |||
| 'reference' | |||
| ] | |||
| } | |||
| all_reqslot = [ | |||
| 'car', 'address', 'postcode', 'phone', 'internet', 'parking', 'type', | |||
| 'pricerange', 'food', 'stars', 'area', 'reference', 'time', 'leave', | |||
| 'price', 'arrive', 'id' | |||
| ] | |||
| informable_slots = { | |||
| 'taxi': ['leave', 'destination', 'departure', 'arrive'], | |||
| 'police': [], | |||
| 'hospital': ['department'], | |||
| 'hotel': [ | |||
| 'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people', | |||
| 'area', 'stars', 'name' | |||
| ], | |||
| 'attraction': ['area', 'type', 'name'], | |||
| 'train': ['destination', 'day', 'arrive', 'departure', 'people', 'leave'], | |||
| 'restaurant': | |||
| ['food', 'pricerange', 'area', 'name', 'time', 'day', 'people'] | |||
| } | |||
| all_infslot = [ | |||
| 'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people', | |||
| 'area', 'stars', 'name', 'leave', 'destination', 'departure', 'arrive', | |||
| 'department', 'food', 'time' | |||
| ] | |||
| all_slots = all_reqslot + [ | |||
| 'stay', 'day', 'people', 'name', 'destination', 'departure', 'department' | |||
| ] | |||
| get_slot = {} | |||
| for s in all_slots: | |||
| get_slot[s] = 1 | |||
| # mapping slots in dialogue act to original goal slot names | |||
| da_abbr_to_slot_name = { | |||
| 'addr': 'address', | |||
| 'fee': 'price', | |||
| 'post': 'postcode', | |||
| 'ref': 'reference', | |||
| 'ticket': 'price', | |||
| 'depart': 'departure', | |||
| 'dest': 'destination', | |||
| } | |||
| dialog_acts = { | |||
| 'restaurant': [ | |||
| 'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook', | |||
| 'offerbooked', 'nobook' | |||
| ], | |||
| 'hotel': [ | |||
| 'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook', | |||
| 'offerbooked', 'nobook' | |||
| ], | |||
| 'attraction': ['inform', 'request', 'nooffer', 'recommend', 'select'], | |||
| 'train': | |||
| ['inform', 'request', 'nooffer', 'offerbook', 'offerbooked', 'select'], | |||
| 'taxi': ['inform', 'request'], | |||
| 'police': ['inform', 'request'], | |||
| 'hospital': ['inform', 'request'], | |||
| # 'booking': ['book', 'inform', 'nobook', 'request'], | |||
| 'general': ['bye', 'greet', 'reqmore', 'welcome'], | |||
| } | |||
| all_acts = [] | |||
| for acts in dialog_acts.values(): | |||
| for act in acts: | |||
| if act not in all_acts: | |||
| all_acts.append(act) | |||
| dialog_act_params = { | |||
| 'inform': all_slots + ['choice', 'open'], | |||
| 'request': all_infslot + ['choice', 'price'], | |||
| 'nooffer': all_slots + ['choice'], | |||
| 'recommend': all_reqslot + ['choice', 'open'], | |||
| 'select': all_slots + ['choice'], | |||
| # 'book': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'], | |||
| 'nobook': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'], | |||
| 'offerbook': all_slots + ['choice'], | |||
| 'offerbooked': all_slots + ['choice'], | |||
| 'reqmore': [], | |||
| 'welcome': [], | |||
| 'bye': [], | |||
| 'greet': [], | |||
| } | |||
| dialog_act_all_slots = all_slots + ['choice', 'open'] | |||
| # special slot tokens in belief span | |||
| # no need of this, just covert slot to [slot] e.g. pricerange -> [pricerange] | |||
| slot_name_to_slot_token = {} | |||
| # eos tokens definition | |||
| eos_tokens = { | |||
| 'user': '<eos_u>', | |||
| 'user_delex': '<eos_u>', | |||
| 'resp': '<eos_r>', | |||
| 'resp_gen': '<eos_r>', | |||
| 'pv_resp': '<eos_r>', | |||
| 'bspn': '<eos_b>', | |||
| 'bspn_gen': '<eos_b>', | |||
| 'pv_bspn': '<eos_b>', | |||
| 'bsdx': '<eos_b>', | |||
| 'bsdx_gen': '<eos_b>', | |||
| 'pv_bsdx': '<eos_b>', | |||
| 'qspn': '<eos_q>', | |||
| 'qspn_gen': '<eos_q>', | |||
| 'pv_qspn': '<eos_q>', | |||
| 'aspn': '<eos_a>', | |||
| 'aspn_gen': '<eos_a>', | |||
| 'pv_aspn': '<eos_a>', | |||
| 'dspn': '<eos_d>', | |||
| 'dspn_gen': '<eos_d>', | |||
| 'pv_dspn': '<eos_d>' | |||
| } | |||
| # sos tokens definition | |||
| sos_tokens = { | |||
| 'user': '<sos_u>', | |||
| 'user_delex': '<sos_u>', | |||
| 'resp': '<sos_r>', | |||
| 'resp_gen': '<sos_r>', | |||
| 'pv_resp': '<sos_r>', | |||
| 'bspn': '<sos_b>', | |||
| 'bspn_gen': '<sos_b>', | |||
| 'pv_bspn': '<sos_b>', | |||
| 'bsdx': '<sos_b>', | |||
| 'bsdx_gen': '<sos_b>', | |||
| 'pv_bsdx': '<sos_b>', | |||
| 'qspn': '<sos_q>', | |||
| 'qspn_gen': '<sos_q>', | |||
| 'pv_qspn': '<sos_q>', | |||
| 'aspn': '<sos_a>', | |||
| 'aspn_gen': '<sos_a>', | |||
| 'pv_aspn': '<sos_a>', | |||
| 'dspn': '<sos_d>', | |||
| 'dspn_gen': '<sos_d>', | |||
| 'pv_dspn': '<sos_d>' | |||
| } | |||
| # db tokens definition | |||
| db_tokens = [ | |||
| '<sos_db>', '<eos_db>', '[book_nores]', '[book_fail]', '[book_success]', | |||
| '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]' | |||
| ] | |||
| # understand tokens definition | |||
| def get_understand_tokens(prompt_num_for_understand): | |||
| understand_tokens = [] | |||
| for i in range(prompt_num_for_understand): | |||
| understand_tokens.append(f'<understand_{i}>') | |||
| return understand_tokens | |||
| # policy tokens definition | |||
| def get_policy_tokens(prompt_num_for_policy): | |||
| policy_tokens = [] | |||
| for i in range(prompt_num_for_policy): | |||
| policy_tokens.append(f'<policy_{i}>') | |||
| return policy_tokens | |||
| # all special tokens definition | |||
| def get_special_tokens(other_tokens): | |||
| special_tokens = [ | |||
| '<go_r>', '<go_b>', '<go_a>', '<go_d>', '<eos_u>', '<eos_r>', | |||
| '<eos_b>', '<eos_a>', '<eos_d>', '<eos_q>', '<sos_u>', '<sos_r>', | |||
| '<sos_b>', '<sos_a>', '<sos_d>', '<sos_q>' | |||
| ] + db_tokens + other_tokens | |||
| return special_tokens | |||
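Side note for reviewers: a minimal sketch of how these ontology helpers are typically stitched together when extending a tokenizer vocabulary. The prompt counts and the `tokenizer` object below are illustrative assumptions, not part of this diff.

```python
# Illustrative only: build the full special-token list from the helpers above.
prompt_num_for_understand = 5  # assumed value, normally read from the model config
prompt_num_for_policy = 5      # assumed value, normally read from the model config

other_tokens = (get_understand_tokens(prompt_num_for_understand)
                + get_policy_tokens(prompt_num_for_policy))
special_tokens = get_special_tokens(other_tokens)

# A HuggingFace-style tokenizer (hypothetical here) would then register them:
# tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
```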
| @@ -0,0 +1,6 @@ | |||
| def hierarchical_set_score(frame1, frame2): | |||
| # deal with empty frame | |||
| if not (frame1 and frame2): | |||
| return 0. | |||
| pass | |||
| return 0. | |||
| @@ -0,0 +1,188 @@ | |||
| import logging | |||
| from collections import OrderedDict | |||
| import json | |||
| import numpy as np | |||
| from . import ontology | |||
| def max_lens(X): | |||
| lens = [len(X)] | |||
| while isinstance(X[0], list): | |||
| lens.append(max(map(len, X))) | |||
| X = [x for xs in X for x in xs] | |||
| return lens | |||
| def list2np(X: object, padding: object = 0, dtype: object = 'int64') -> object: | |||
| shape = max_lens(X) | |||
| ret = np.full(shape, padding, dtype=np.int32) | |||
| if len(shape) == 1: | |||
| ret = np.array(X) | |||
| elif len(shape) == 2: | |||
| for i, x in enumerate(X): | |||
| ret[i, :len(x)] = np.array(x) | |||
| elif len(shape) == 3: | |||
| for i, xs in enumerate(X): | |||
| for j, x in enumerate(xs): | |||
| ret[i, j, :len(x)] = np.array(x) | |||
| return ret.astype(dtype) | |||
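A small worked example of the padding behavior of `max_lens`/`list2np`; the batch values are made up for illustration.

```python
# Assumes list2np from this module is in scope.
batch = [[1, 2, 3], [4, 5]]      # ragged batch of token ids
arr = list2np(batch, padding=0)  # max_lens(batch) == [2, 3]
# arr is a (2, 3) int64 array:
# [[1, 2, 3],
#  [4, 5, 0]]
print(arr.shape, arr.dtype)
```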
| def clean_replace(s, r, t, forward=True, backward=False): | |||
| def clean_replace_single(s, r, t, forward, backward, sidx=0): | |||
| # idx = s[sidx:].find(r) | |||
| idx = s.find(r) | |||
| if idx == -1: | |||
| return s, -1 | |||
| idx_r = idx + len(r) | |||
| if backward: | |||
| while idx > 0 and s[idx - 1]: | |||
| idx -= 1 | |||
| elif idx > 0 and s[idx - 1] != ' ': | |||
| return s, -1 | |||
| if forward: | |||
| while \ | |||
| idx_r < len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()): | |||
| idx_r += 1 | |||
| elif idx_r != len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()): | |||
| return s, -1 | |||
| return s[:idx] + t + s[idx_r:], idx_r | |||
| sidx = 0 | |||
| while sidx != -1: | |||
| s, sidx = clean_replace_single(s, r, t, forward, backward, sidx) | |||
| return s | |||
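And an illustrative call to `clean_replace`, which only substitutes a value when it is not glued to neighbouring word characters; the example strings are mine, not taken from the tests.

```python
# Delexicalization-style replacement: swap a concrete slot value for a placeholder.
s = 'i want to leave after 17:15 .'
print(clean_replace(s, '17:15', '[value_leave]'))
# -> 'i want to leave after [value_leave] .'
```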
| def py2np(list): | |||
| return np.array(list) | |||
| def write_dict(fn, dic): | |||
| with open(fn, 'w') as f: | |||
| json.dump(dic, f, indent=2) | |||
| def f1_score(label_list, pred_list): | |||
| tp = len([t for t in pred_list if t in label_list]) | |||
| fp = max(0, len(pred_list) - tp) | |||
| fn = max(0, len(label_list) - tp) | |||
| precision = tp / (tp + fp + 1e-10) | |||
| recall = tp / (tp + fn + 1e-10) | |||
| f1 = 2 * precision * recall / (precision + recall + 1e-10) | |||
| return f1 | |||
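For clarity, the smoothed set-level F1 this helper computes, shown on a toy pair of slot lists (values chosen for illustration):

```python
# One overlapping slot out of two gold and two predicted slots:
# precision = recall ~ 0.5, so F1 ~ 0.5 (the 1e-10 terms only guard
# against division by zero).
label_list = ['area', 'food']
pred_list = ['area', 'pricerange']
print(round(f1_score(label_list, pred_list), 3))  # 0.5
```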
| class MultiWOZVocab(object): | |||
| def __init__(self, vocab_size=0): | |||
| """ | |||
| vocab for multiwoz dataset | |||
| """ | |||
| self.vocab_size = vocab_size | |||
| self.vocab_size_oov = 0 # get after construction | |||
| self._idx2word = {} # word + oov | |||
| self._word2idx = {} # word | |||
| self._freq_dict = {} # word + oov | |||
| for w in [ | |||
| '[PAD]', '<go_r>', '[UNK]', '<go_b>', '<go_a>', '<eos_u>', | |||
| '<eos_r>', '<eos_b>', '<eos_a>', '<go_d>', '<eos_d>' | |||
| ]: | |||
| self._absolute_add_word(w) | |||
| def _absolute_add_word(self, w): | |||
| idx = len(self._idx2word) | |||
| self._idx2word[idx] = w | |||
| self._word2idx[w] = idx | |||
| def add_word(self, word): | |||
| if word not in self._freq_dict: | |||
| self._freq_dict[word] = 0 | |||
| self._freq_dict[word] += 1 | |||
| def has_word(self, word): | |||
| return self._freq_dict.get(word) | |||
| def _add_to_vocab(self, word): | |||
| if word not in self._word2idx: | |||
| idx = len(self._idx2word) | |||
| self._idx2word[idx] = word | |||
| self._word2idx[word] = idx | |||
| def construct(self): | |||
| freq_dict_sorted = sorted( | |||
| self._freq_dict.keys(), key=lambda x: -self._freq_dict[x]) | |||
| print('Vocabulary size including oov: %d' % | |||
| (len(freq_dict_sorted) + len(self._idx2word))) | |||
| if len(freq_dict_sorted) + len(self._idx2word) < self.vocab_size: | |||
| logging.warning( | |||
| 'actual label set smaller than that configured: {}/{}'.format( | |||
| len(freq_dict_sorted) + len(self._idx2word), | |||
| self.vocab_size)) | |||
| for word in ontology.all_domains + ['general']: | |||
| word = '[' + word + ']' | |||
| self._add_to_vocab(word) | |||
| for word in ontology.all_acts: | |||
| word = '[' + word + ']' | |||
| self._add_to_vocab(word) | |||
| for word in ontology.all_slots: | |||
| self._add_to_vocab(word) | |||
| for word in freq_dict_sorted: | |||
| if word.startswith('[value_') and word.endswith(']'): | |||
| self._add_to_vocab(word) | |||
| for word in freq_dict_sorted: | |||
| self._add_to_vocab(word) | |||
| self.vocab_size_oov = len(self._idx2word) | |||
| def load_vocab(self, vocab_path): | |||
| self._freq_dict = json.loads( | |||
| open(vocab_path + '.freq.json', 'r').read()) | |||
| self._word2idx = json.loads( | |||
| open(vocab_path + '.word2idx.json', 'r').read()) | |||
| self._idx2word = {} | |||
| for w, idx in self._word2idx.items(): | |||
| self._idx2word[idx] = w | |||
| self.vocab_size_oov = len(self._idx2word) | |||
| print('vocab file loaded from "' + vocab_path + '"') | |||
| print('Vocabulary size including oov: %d' % (self.vocab_size_oov)) | |||
| def save_vocab(self, vocab_path): | |||
| _freq_dict = OrderedDict( | |||
| sorted( | |||
| self._freq_dict.items(), key=lambda kv: kv[1], reverse=True)) | |||
| write_dict(vocab_path + '.word2idx.json', self._word2idx) | |||
| write_dict(vocab_path + '.freq.json', _freq_dict) | |||
| def encode(self, word, include_oov=True): | |||
| if include_oov: | |||
| if self._word2idx.get(word, None) is None: | |||
| raise ValueError( | |||
| 'Unknown word: %s. Vocabulary should include oovs here.' | |||
| % word) | |||
| return self._word2idx[word] | |||
| else: | |||
| word = '<unk>' if word not in self._word2idx else word | |||
| return self._word2idx[word] | |||
| def sentence_encode(self, word_list): | |||
| return [self.encode(_) for _ in word_list] | |||
| def oov_idx_map(self, idx): | |||
| return 2 if idx > self.vocab_size else idx | |||
| def sentence_oov_map(self, index_list): | |||
| return [self.oov_idx_map(_) for _ in index_list] | |||
| def decode(self, idx, indicate_oov=False): | |||
| if not self._idx2word.get(idx): | |||
| raise ValueError( | |||
| 'Error idx: %d. Vocabulary should include oovs here.' % idx) | |||
| if not indicate_oov or idx < self.vocab_size: | |||
| return self._idx2word[idx] | |||
| else: | |||
| return self._idx2word[idx] + '(o)' | |||
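A minimal usage sketch of `MultiWOZVocab`; the toy corpus and vocab size below are placeholders, since the real preprocessing builds the frequency dict from the MultiWOZ data files.

```python
# Count word frequencies first, then freeze the vocabulary.
vocab = MultiWOZVocab(vocab_size=3000)
for sentence in ['i want a cheap restaurant', 'book a taxi please']:
    for word in sentence.split():
        vocab.add_word(word)
vocab.construct()  # also registers [domain]/[act] tokens and the slot names

idx = vocab.encode('restaurant')
print(idx, vocab.decode(idx))  # some id, 'restaurant'
```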
| @@ -1 +1,3 @@ | |||
| sofa==1.0.4.2 | |||
| https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz | |||
| sofa==1.0.5 | |||
| spacy>=2.3.5 | |||
| @@ -0,0 +1,61 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import SpaceForDialogIntent | |||
| from modelscope.pipelines import DialogIntentPredictionPipeline, pipeline | |||
| from modelscope.preprocessors import DialogIntentPredictionPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class DialogIntentPredictionTest(unittest.TestCase): | |||
| model_id = 'damo/nlp_space_dialog-intent-prediction' | |||
| test_case = [ | |||
| 'How do I locate my card?', | |||
| 'I still have not received my new card, I ordered over a week ago.' | |||
| ] | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path) | |||
| model = SpaceForDialogIntent( | |||
| model_dir=cache_path, | |||
| text_field=preprocessor.text_field, | |||
| config=preprocessor.config) | |||
| pipelines = [ | |||
| DialogIntentPredictionPipeline( | |||
| model=model, preprocessor=preprocessor), | |||
| pipeline( | |||
| task=Tasks.dialog_intent_prediction, | |||
| model=model, | |||
| preprocessor=preprocessor) | |||
| ] | |||
| for my_pipeline, item in zip(pipelines, self.test_case): | |||
| print(my_pipeline(item)) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| preprocessor = DialogIntentPredictionPreprocessor( | |||
| model_dir=model.model_dir) | |||
| pipelines = [ | |||
| DialogIntentPredictionPipeline( | |||
| model=model, preprocessor=preprocessor), | |||
| pipeline( | |||
| task=Tasks.dialog_intent_prediction, | |||
| model=model, | |||
| preprocessor=preprocessor) | |||
| ] | |||
| for my_pipeline, item in zip(pipelines, self.test_case): | |||
| print(my_pipeline(item)) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,147 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import SpaceForDialogModeling | |||
| from modelscope.pipelines import DialogModelingPipeline, pipeline | |||
| from modelscope.preprocessors import DialogModelingPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class DialogModelingTest(unittest.TestCase): | |||
| model_id = 'damo/nlp_space_dialog-modeling' | |||
| test_case = { | |||
| 'sng0073': { | |||
| 'goal': { | |||
| 'taxi': { | |||
| 'info': { | |||
| 'leaveat': '17:15', | |||
| 'destination': 'pizza hut fen ditton', | |||
| 'departure': "saint john's college" | |||
| }, | |||
| 'reqt': ['car', 'phone'], | |||
| 'fail_info': {} | |||
| } | |||
| }, | |||
| 'log': [{ | |||
| 'user': | |||
| "i would like a taxi from saint john 's college to pizza hut fen ditton .", | |||
| 'user_delex': | |||
| 'i would like a taxi from [value_departure] to [value_destination] .', | |||
| 'resp': | |||
| 'what time do you want to leave and what time do you want to arrive by ?', | |||
| 'sys': | |||
| 'what time do you want to leave and what time do you want to arrive by ?', | |||
| 'pointer': '0,0,0,0,0,0', | |||
| 'match': '', | |||
| 'constraint': | |||
| "[taxi] destination pizza hut fen ditton departure saint john 's college", | |||
| 'cons_delex': '[taxi] destination departure', | |||
| 'sys_act': '[taxi] [request] leave arrive', | |||
| 'turn_num': 0, | |||
| 'turn_domain': '[taxi]' | |||
| }, { | |||
| 'user': 'i want to leave after 17:15 .', | |||
| 'user_delex': 'i want to leave after [value_leave] .', | |||
| 'resp': | |||
| 'booking completed ! your taxi will be [value_car] contact number is [value_phone]', | |||
| 'sys': | |||
| 'booking completed ! your taxi will be blue honda contact number is 07218068540', | |||
| 'pointer': '0,0,0,0,0,0', | |||
| 'match': '', | |||
| 'constraint': | |||
| "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", | |||
| 'cons_delex': '[taxi] destination departure leave', | |||
| 'sys_act': '[taxi] [inform] car phone', | |||
| 'turn_num': 1, | |||
| 'turn_domain': '[taxi]' | |||
| }, { | |||
| 'user': 'thank you for all the help ! i appreciate it .', | |||
| 'user_delex': 'thank you for all the help ! i appreciate it .', | |||
| 'resp': | |||
| 'you are welcome . is there anything else i can help you with today ?', | |||
| 'sys': | |||
| 'you are welcome . is there anything else i can help you with today ?', | |||
| 'pointer': '0,0,0,0,0,0', | |||
| 'match': '', | |||
| 'constraint': | |||
| "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", | |||
| 'cons_delex': '[taxi] destination departure leave', | |||
| 'sys_act': '[general] [reqmore]', | |||
| 'turn_num': 2, | |||
| 'turn_domain': '[general]' | |||
| }, { | |||
| 'user': 'no , i am all set . have a nice day . bye .', | |||
| 'user_delex': 'no , i am all set . have a nice day . bye .', | |||
| 'resp': 'you too ! thank you', | |||
| 'sys': 'you too ! thank you', | |||
| 'pointer': '0,0,0,0,0,0', | |||
| 'match': '', | |||
| 'constraint': | |||
| "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", | |||
| 'cons_delex': '[taxi] destination departure leave', | |||
| 'sys_act': '[general] [bye]', | |||
| 'turn_num': 3, | |||
| 'turn_domain': '[general]' | |||
| }] | |||
| } | |||
| } | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| preprocessor = DialogModelingPreprocessor(model_dir=cache_path) | |||
| model = SpaceForDialogModeling( | |||
| model_dir=cache_path, | |||
| text_field=preprocessor.text_field, | |||
| config=preprocessor.config) | |||
| pipelines = [ | |||
| DialogModelingPipeline(model=model, preprocessor=preprocessor), | |||
| pipeline( | |||
| task=Tasks.dialog_modeling, | |||
| model=model, | |||
| preprocessor=preprocessor) | |||
| ] | |||
| result = {} | |||
| for step, item in enumerate(self.test_case['sng0073']['log']): | |||
| user = item['user'] | |||
| print('user: {}'.format(user)) | |||
| result = pipelines[step % 2]({ | |||
| 'user_input': user, | |||
| 'history': result | |||
| }) | |||
| print('response : {}'.format(result['response'])) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| preprocessor = DialogModelingPreprocessor(model_dir=model.model_dir) | |||
| pipelines = [ | |||
| DialogModelingPipeline(model=model, preprocessor=preprocessor), | |||
| pipeline( | |||
| task=Tasks.dialog_modeling, | |||
| model=model, | |||
| preprocessor=preprocessor) | |||
| ] | |||
| result = {} | |||
| for step, item in enumerate(self.test_case['sng0073']['log']): | |||
| user = item['user'] | |||
| print('user: {}'.format(user)) | |||
| result = pipelines[step % 2]({ | |||
| 'user_input': user, | |||
| 'history': result | |||
| }) | |||
| print('response : {}'.format(result['response'])) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,52 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import SbertForNLI | |||
| from modelscope.pipelines import NLIPipeline, pipeline | |||
| from modelscope.preprocessors import NLIPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class NLITest(unittest.TestCase): | |||
| model_id = 'damo/nlp_structbert_nli_chinese-base' | |||
| sentence1 = '四川商务职业学院和四川财经职业学院哪个好?' | |||
| sentence2 = '四川商务职业学院商务管理在哪个校区?' | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_direct_file_download(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| tokenizer = NLIPreprocessor(cache_path) | |||
| model = SbertForNLI(cache_path, tokenizer=tokenizer) | |||
| pipeline1 = NLIPipeline(model, preprocessor=tokenizer) | |||
| pipeline2 = pipeline(Tasks.nli, model=model, preprocessor=tokenizer) | |||
| print(f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n' | |||
| f'pipeline1: {pipeline1(input=(self.sentence1, self.sentence2))}') | |||
| print() | |||
| print( | |||
| f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n' | |||
| f'pipeline2: {pipeline2(input=(self.sentence1, self.sentence2))}') | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| tokenizer = NLIPreprocessor(model.model_dir) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.nli, model=model, preprocessor=tokenizer) | |||
| print(pipeline_ins(input=(self.sentence1, self.sentence2))) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipeline_ins = pipeline(task=Tasks.nli, model=self.model_id) | |||
| print(pipeline_ins(input=(self.sentence1, self.sentence2))) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipeline_ins = pipeline(task=Tasks.nli) | |||
| print(pipeline_ins(input=(self.sentence1, self.sentence2))) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,58 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import SbertForSentimentClassification | |||
| from modelscope.pipelines import SentimentClassificationPipeline, pipeline | |||
| from modelscope.preprocessors import SentimentClassificationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class SentimentClassificationTest(unittest.TestCase): | |||
| model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | |||
| sentence1 = '启动的时候很大声音,然后就会听到1.2秒的卡察的声音,类似齿轮摩擦的声音' | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_direct_file_download(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| tokenizer = SentimentClassificationPreprocessor(cache_path) | |||
| model = SbertForSentimentClassification( | |||
| cache_path, tokenizer=tokenizer) | |||
| pipeline1 = SentimentClassificationPipeline( | |||
| model, preprocessor=tokenizer) | |||
| pipeline2 = pipeline( | |||
| Tasks.sentiment_classification, | |||
| model=model, | |||
| preprocessor=tokenizer) | |||
| print(f'sentence1: {self.sentence1}\n' | |||
| f'pipeline1: {pipeline1(input=self.sentence1)}') | |||
| print() | |||
| print(f'sentence1: {self.sentence1}\n' | |||
| f'pipeline2: {pipeline2(input=self.sentence1)}') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| tokenizer = SentimentClassificationPreprocessor(model.model_dir) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.sentiment_classification, | |||
| model=model, | |||
| preprocessor=tokenizer) | |||
| print(pipeline_ins(input=self.sentence1)) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.sentiment_classification, model=self.model_id) | |||
| print(pipeline_ins(input=self.sentence1)) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipeline_ins = pipeline(task=Tasks.sentiment_classification) | |||
| print(pipeline_ins(input=self.sentence1)) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -4,7 +4,7 @@ import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import StructBertForTokenClassification | |||
| from modelscope.models.nlp import SbertForTokenClassification | |||
| from modelscope.pipelines import WordSegmentationPipeline, pipeline | |||
| from modelscope.preprocessors import TokenClassifcationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| @@ -19,8 +19,7 @@ class WordSegmentationTest(unittest.TestCase): | |||
| def test_run_by_direct_model_download(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| tokenizer = TokenClassifcationPreprocessor(cache_path) | |||
| model = StructBertForTokenClassification( | |||
| cache_path, tokenizer=tokenizer) | |||
| model = SbertForTokenClassification(cache_path, tokenizer=tokenizer) | |||
| pipeline1 = WordSegmentationPipeline(model, preprocessor=tokenizer) | |||
| pipeline2 = pipeline( | |||
| Tasks.word_segmentation, model=model, preprocessor=tokenizer) | |||