diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 21e13252..1f8440de 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -62,6 +62,7 @@ class Pipelines(object): nli = 'nli' dialog_intent_prediction = 'dialog-intent-prediction' dialog_modeling = 'dialog-modeling' + dialog_state_tracking = 'dialog-state-tracking' zero_shot_classification = 'zero-shot-classification' # audio tasks @@ -112,6 +113,7 @@ class Preprocessors(object): sen_cls_tokenizer = 'sen-cls-tokenizer' dialog_intent_preprocessor = 'dialog-intent-preprocessor' dialog_modeling_preprocessor = 'dialog-modeling-preprocessor' + dialog_state_tracking_preprocessor = 'dialog-state-tracking-preprocessor' sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer' zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' diff --git a/modelscope/models/__init__.py b/modelscope/models/__init__.py index eec7ba26..d11151b5 100644 --- a/modelscope/models/__init__.py +++ b/modelscope/models/__init__.py @@ -15,12 +15,14 @@ except ModuleNotFoundError as e: try: from .audio.kws import GenericKeyWordSpotting from .multi_modal import OfaForImageCaptioning - from .nlp import ( - BertForMaskedLM, BertForSequenceClassification, CsanmtForTranslation, - SbertForNLI, SbertForSentenceSimilarity, - SbertForSentimentClassification, SbertForTokenClassification, - SbertForZeroShotClassification, SpaceForDialogIntent, - SpaceForDialogModeling, StructBertForMaskedLM, VecoForMaskedLM) + from .nlp import (BertForMaskedLM, BertForSequenceClassification, + CsanmtForTranslation, SbertForNLI, + SbertForSentenceSimilarity, + SbertForSentimentClassification, + SbertForTokenClassification, + SbertForZeroShotClassification, SpaceForDialogIntent, + SpaceForDialogModeling, SpaceForDialogStateTracking, + StructBertForMaskedLM, VecoForMaskedLM) from .audio.ans.frcrn import FRCRNModel except ModuleNotFoundError as e: if str(e) == "No module named 'pytorch'": diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py index e14bdc9c..fb1ff063 100644 --- a/modelscope/models/nlp/__init__.py +++ b/modelscope/models/nlp/__init__.py @@ -9,3 +9,4 @@ from .sbert_for_token_classification import * # noqa F403 from .sbert_for_zero_shot_classification import * # noqa F403 from .space.dialog_intent_prediction_model import * # noqa F403 from .space.dialog_modeling_model import * # noqa F403 +from .space.dialog_state_tracking_model import * # noqa F403 diff --git a/modelscope/models/nlp/space/dialog_intent_prediction_model.py b/modelscope/models/nlp/space/dialog_intent_prediction_model.py index 644af4c7..a75dc1a4 100644 --- a/modelscope/models/nlp/space/dialog_intent_prediction_model.py +++ b/modelscope/models/nlp/space/dialog_intent_prediction_model.py @@ -63,15 +63,16 @@ class SpaceForDialogIntent(Model): """return the result by the model Args: - input (Dict[str, Any]): the preprocessed data + input (Dict[str, Tensor]): the preprocessed data Returns: - Dict[str, np.ndarray]: results + Dict[str, Tensor]: results Example: { - 'predictions': array([1]), # lable 0-negative 1-positive - 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), - 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value + 'pred': array([2.62349960e-03 4.12110658e-03 4.12748595e-05 3.77560973e-05 + 1.08599677e-04 1.72710388e-05 2.95618793e-05 1.93638436e-04 + 6.45841064e-05 1.15997791e-04 5.11605394e-05 9.87020373e-01 + 2.66957268e-05 4.72324500e-05 9.74208378e-05], dtype=float32) } """ import numpy as np diff --git 
a/modelscope/models/nlp/space/dialog_modeling_model.py b/modelscope/models/nlp/space/dialog_modeling_model.py index 872155e2..e922d073 100644 --- a/modelscope/models/nlp/space/dialog_modeling_model.py +++ b/modelscope/models/nlp/space/dialog_modeling_model.py @@ -62,15 +62,17 @@ class SpaceForDialogModeling(Model): """return the result by the model Args: - input (Dict[str, Any]): the preprocessed data + input (Dict[str, Tensor]): the preprocessed data Returns: - Dict[str, np.ndarray]: results + Dict[str, Tensor]: results Example: { - 'predictions': array([1]), # lable 0-negative 1-positive - 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), - 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value + 'labels': array([1,192,321,12]), # label ids + 'resp': array([293,1023,123,1123]), # vocab ids for the response + 'bspn': array([123,321,2,24,1]), + 'aspn': array([47,8345,32,29,1983]), + 'db': array([19, 24, 20]), } """ } diff --git a/modelscope/models/nlp/space/dialog_state_tracking_model.py b/modelscope/models/nlp/space/dialog_state_tracking_model.py new file mode 100644 index 00000000..30f21acb --- /dev/null +++ b/modelscope/models/nlp/space/dialog_state_tracking_model.py @@ -0,0 +1,103 @@ +import os +from typing import Any, Dict + +from modelscope.utils.constant import Tasks +from ....metainfo import Models +from ....utils.nlp.space.utils_dst import batch_to_device +from ...base import Model, Tensor +from ...builder import MODELS + +__all__ = ['SpaceForDialogStateTracking'] + + +@MODELS.register_module(Tasks.dialog_state_tracking, module_name=Models.space) +class SpaceForDialogStateTracking(Model): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the dialog state tracking model from the `model_dir` path. + + Args: + model_dir (str): the model path.
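+ + Example (the checkpoint path below is illustrative, not a real model id): + >>> from modelscope.models.nlp import SpaceForDialogStateTracking + >>> model = SpaceForDialogStateTracking('/path/to/nlp_space_dialog-state-tracking')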
+ """ + + super().__init__(model_dir, *args, **kwargs) + + from sofa.models.space import SpaceForDST, SpaceConfig + self.model_dir = model_dir + + self.config = SpaceConfig.from_pretrained(self.model_dir) + self.model = SpaceForDST.from_pretrained(self.model_dir) + self.model.to(self.config.device) + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + Returns: + Dict[str, Tensor]: results + Example: + { + 'inputs': dict(input_ids, input_masks,start_pos), # tracking states + 'outputs': dict(slots_logits), + 'unique_ids': str(test-example.json-0), # default value + 'input_ids_unmasked': array([101, 7632, 1010,0,0,0]) + 'values': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]), + 'inform': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]), + 'prefix': str('final'), #default value + 'ds': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]) + } + """ + import numpy as np + import torch + + self.model.eval() + batch = input['batch'] + batch = batch_to_device(batch, self.config.device) + + features = input['features'] + diag_state = input['diag_state'] + turn_itrs = [features[i.item()].guid.split('-')[2] for i in batch[9]] + reset_diag_state = np.where(np.array(turn_itrs) == '0')[0] + for slot in self.config.dst_slot_list: + for i in reset_diag_state: + diag_state[slot][i] = 0 + + with torch.no_grad(): + inputs = { + 'input_ids': batch[0], + 'input_mask': batch[1], + 'segment_ids': batch[2], + 'start_pos': batch[3], + 'end_pos': batch[4], + 'inform_slot_id': batch[5], + 'refer_id': batch[6], + 'diag_state': diag_state, + 'class_label_id': batch[8] + } + unique_ids = [features[i.item()].guid for i in batch[9]] + values = [features[i.item()].values for i in batch[9]] + input_ids_unmasked = [ + features[i.item()].input_ids_unmasked for i in batch[9] + ] + inform = [features[i.item()].inform for i in batch[9]] + outputs = self.model(**inputs) + + # Update dialog state for next turn. 
+ for slot in self.config.dst_slot_list: + updates = outputs[2][slot].max(1)[1] + for i, u in enumerate(updates): + if u != 0: + diag_state[slot][i] = u + + return { + 'inputs': inputs, + 'outputs': outputs, + 'unique_ids': unique_ids, + 'input_ids_unmasked': input_ids_unmasked, + 'values': values, + 'inform': inform, + 'prefix': 'final', + 'ds': input['ds'] + } diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index bdf2cc17..e2257ff4 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -41,6 +41,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/nlp_space_dialog-intent-prediction'), Tasks.dialog_modeling: (Pipelines.dialog_modeling, 'damo/nlp_space_dialog-modeling'), + Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking, + 'damo/nlp_space_dialog-state-tracking'), Tasks.image_captioning: (Pipelines.image_caption, 'damo/ofa_image-caption_coco_large_en'), Tasks.image_generation: diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py index 5a46b359..c0671de7 100644 --- a/modelscope/pipelines/nlp/__init__.py +++ b/modelscope/pipelines/nlp/__init__.py @@ -1,6 +1,7 @@ try: from .dialog_intent_prediction_pipeline import * # noqa F403 from .dialog_modeling_pipeline import * # noqa F403 + from .dialog_state_tracking_pipeline import * # noqa F403 from .fill_mask_pipeline import * # noqa F403 from .nli_pipeline import * # noqa F403 from .sentence_similarity_pipeline import * # noqa F403 diff --git a/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py b/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py index 45844b30..0fd863a6 100644 --- a/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py +++ b/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py @@ -1,8 +1,9 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
-from typing import Any, Dict +from typing import Any, Dict, Union from ...metainfo import Pipelines +from ...models import Model from ...models.nlp import SpaceForDialogIntent from ...preprocessors import DialogIntentPredictionPreprocessor from ...utils.constant import Tasks @@ -18,17 +19,22 @@ __all__ = ['DialogIntentPredictionPipeline'] module_name=Pipelines.dialog_intent_prediction) class DialogIntentPredictionPipeline(Pipeline): - def __init__(self, model: SpaceForDialogIntent, - preprocessor: DialogIntentPredictionPreprocessor, **kwargs): - """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction + def __init__(self, + model: Union[SpaceForDialogIntent, str], + preprocessor: DialogIntentPredictionPreprocessor = None, + **kwargs): + """use `model` and `preprocessor` to create a dialog intent prediction pipeline Args: - model (SequenceClassificationModel): a model instance - preprocessor (SequenceClassificationPreprocessor): a preprocessor instance + model (Union[SpaceForDialogIntent, str]): a model instance or a model id/path + preprocessor (DialogIntentPredictionPreprocessor): a preprocessor instance, built from the model dir when None """ - - super().__init__(model=model, preprocessor=preprocessor, **kwargs) + model = model if isinstance( + model, SpaceForDialogIntent) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = DialogIntentPredictionPreprocessor(model.model_dir) self.model = model + super().__init__(model=model, preprocessor=preprocessor, **kwargs) self.categories = preprocessor.categories def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: diff --git a/modelscope/pipelines/nlp/dialog_modeling_pipeline.py b/modelscope/pipelines/nlp/dialog_modeling_pipeline.py index 746a6255..80a0f783 100644 --- a/modelscope/pipelines/nlp/dialog_modeling_pipeline.py +++ b/modelscope/pipelines/nlp/dialog_modeling_pipeline.py @@ -1,8 +1,9 @@ # Copyright (c) Alibaba, Inc. and its affiliates.
-from typing import Any, Dict, Optional +from typing import Any, Dict, Union from ...metainfo import Pipelines +from ...models import Model from ...models.nlp import SpaceForDialogModeling from ...preprocessors import DialogModelingPreprocessor from ...utils.constant import Tasks @@ -17,17 +18,22 @@ __all__ = ['DialogModelingPipeline'] Tasks.dialog_modeling, module_name=Pipelines.dialog_modeling) class DialogModelingPipeline(Pipeline): - def __init__(self, model: SpaceForDialogModeling, - preprocessor: DialogModelingPreprocessor, **kwargs): - """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction + def __init__(self, + model: Union[SpaceForDialogModeling, str], + preprocessor: DialogModelingPreprocessor = None, + **kwargs): + """use `model` and `preprocessor` to create a dialog modeling pipeline for dialog response generation Args: - model (SequenceClassificationModel): a model instance - preprocessor (SequenceClassificationPreprocessor): a preprocessor instance + model (Union[SpaceForDialogModeling, str]): a model instance or a model id/path + preprocessor (DialogModelingPreprocessor): a preprocessor instance, built from the model dir when None """ - - super().__init__(model=model, preprocessor=preprocessor, **kwargs) + model = model if isinstance( + model, SpaceForDialogModeling) else Model.from_pretrained(model) self.model = model + if preprocessor is None: + preprocessor = DialogModelingPreprocessor(model.model_dir) + super().__init__(model=model, preprocessor=preprocessor, **kwargs) self.preprocessor = preprocessor def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: @@ -43,7 +49,6 @@ class DialogModelingPipeline(Pipeline): inputs['resp']) assert len(sys_rsp) > 2 sys_rsp = sys_rsp[1:len(sys_rsp) - 1] - inputs[OutputKeys.RESPONSE] = sys_rsp return inputs diff --git a/modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py b/modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py new file mode 100644 index 00000000..9c2c9b0d --- /dev/null +++ b/modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py @@ -0,0 +1,159 @@ +from typing import Any, Dict, Union + +from ...metainfo import Pipelines +from ...models import Model, SpaceForDialogStateTracking +from ...preprocessors import DialogStateTrackingPreprocessor +from ...utils.constant import Tasks +from ..base import Pipeline +from ..builder import PIPELINES +from ..outputs import OutputKeys + +__all__ = ['DialogStateTrackingPipeline'] + + +@PIPELINES.register_module( + Tasks.dialog_state_tracking, module_name=Pipelines.dialog_state_tracking) +class DialogStateTrackingPipeline(Pipeline): + + def __init__(self, + model: Union[SpaceForDialogStateTracking, str], + preprocessor: DialogStateTrackingPreprocessor = None, + **kwargs): + """use `model` and `preprocessor` to create a dialog state tracking pipeline that + tracks the dialog states across the turns of a multi-domain dialogue + + Args: + model (Union[SpaceForDialogStateTracking, str]): a model instance or a model id/path + preprocessor (DialogStateTrackingPreprocessor): a preprocessor instance, built from the model dir when None + """ + + model = model if isinstance( + model, + SpaceForDialogStateTracking) else Model.from_pretrained(model) + self.model = model + if preprocessor is None: + preprocessor = DialogStateTrackingPreprocessor(model.model_dir) + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + self.tokenizer = preprocessor.tokenizer + self.config = preprocessor.config + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): the forward outputs: model inputs and outputs, unique ids, + unmasked input ids, ground-truth values, inform labels and the previous dialog state + + Returns: + Dict[str, str]: the prediction results + """ + + _inputs = inputs['inputs'] + _outputs = inputs['outputs'] + unique_ids = inputs['unique_ids'] + input_ids_unmasked = inputs['input_ids_unmasked'] + values = inputs['values'] + inform = inputs['inform'] + prefix = inputs['prefix'] + ds = inputs['ds'] + ds = predict_and_format(self.config, self.tokenizer, _inputs, + _outputs[2], _outputs[3], _outputs[4], + _outputs[5], unique_ids, input_ids_unmasked, + values, inform, prefix, ds) + + return {OutputKeys.DIALOG_STATES: ds} + + +def predict_and_format(config, tokenizer, features, per_slot_class_logits, + per_slot_start_logits, per_slot_end_logits, + per_slot_refer_logits, ids, input_ids_unmasked, values, + inform, prefix, ds): + import re + + prediction_list = [] + dialog_state = ds + for i in range(len(ids)): + if int(ids[i].split('-')[2]) == 0: + dialog_state = {slot: 'none' for slot in config.dst_slot_list} + + prediction = {} + prediction_addendum = {} + for slot in config.dst_slot_list: + class_logits = per_slot_class_logits[slot][i] + start_logits = per_slot_start_logits[slot][i] + end_logits = per_slot_end_logits[slot][i] + refer_logits = per_slot_refer_logits[slot][i] + + input_ids = features['input_ids'][i].tolist() + class_label_id = int(features['class_label_id'][slot][i]) + start_pos = int(features['start_pos'][slot][i]) + end_pos = int(features['end_pos'][slot][i]) + refer_id = int(features['refer_id'][slot][i]) + + class_prediction = int(class_logits.argmax()) + start_prediction = int(start_logits.argmax()) + end_prediction = int(end_logits.argmax()) + refer_prediction = int(refer_logits.argmax()) + + prediction['guid'] = ids[i].split('-') + prediction['class_prediction_%s' % slot] = class_prediction + prediction['class_label_id_%s' % slot] = class_label_id + prediction['start_prediction_%s' % slot] = start_prediction + prediction['start_pos_%s' % slot] = start_pos + prediction['end_prediction_%s' % slot] = end_prediction + prediction['end_pos_%s' % slot] = end_pos + prediction['refer_prediction_%s' % slot] = refer_prediction + prediction['refer_id_%s' % slot] = refer_id + prediction['input_ids_%s' % slot] = input_ids + + if class_prediction == config.dst_class_types.index('dontcare'): + dialog_state[slot] = 'dontcare' + elif class_prediction == config.dst_class_types.index( + 'copy_value'): + input_tokens = tokenizer.convert_ids_to_tokens( + input_ids_unmasked[i]) + dialog_state[slot] = ' '.join( + input_tokens[start_prediction:end_prediction + 1]) + dialog_state[slot] = re.sub('(^| )##', '', dialog_state[slot]) + elif 'true' in config.dst_class_types and class_prediction == config.dst_class_types.index( + 'true'): + dialog_state[slot] = 'true' + elif 'false' in config.dst_class_types and class_prediction == config.dst_class_types.index( + 'false'): + dialog_state[slot] = 'false' + elif class_prediction == config.dst_class_types.index('inform'): + # dialog_state[slot] = '§§' + inform[i][slot] + if isinstance(inform[i][slot], str): + dialog_state[slot] = inform[i][slot] + elif isinstance(inform[i][slot], list): + dialog_state[slot] = inform[i][slot][0] + # Referral case is handled below + + prediction_addendum['slot_prediction_%s' + % slot] = dialog_state[slot] + prediction_addendum['slot_groundtruth_%s' % slot] = values[i][slot] + + # Referral case. All other slot values need to be seen first in order + # to be able to do this correctly.
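+ # A 'refer' class prediction copies the value from another slot filled earlier + # in the dialog; refer_prediction - 1 indexes into config.dst_slot_list because + # refer id 0 is reserved for 'none' (cf. refer_list in dst_processors).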
+ for slot in config.dst_slot_list: + class_logits = per_slot_class_logits[slot][i] + refer_logits = per_slot_refer_logits[slot][i] + + class_prediction = int(class_logits.argmax()) + refer_prediction = int(refer_logits.argmax()) + + if 'refer' in config.dst_class_types and class_prediction == config.dst_class_types.index( + 'refer'): + # Only slots that have been mentioned before can be referred to. + # One can think of a situation where one slot is referred to in the same utterance. + # This phenomenon is however currently not properly covered in the training data + # label generation process. + dialog_state[slot] = dialog_state[config.dst_slot_list[ + refer_prediction - 1]] + prediction_addendum['slot_prediction_%s' % + slot] = dialog_state[slot] # Value update + + prediction.update(prediction_addendum) + prediction_list.append(prediction) + + return dialog_state diff --git a/modelscope/pipelines/nlp/sentiment_classification_pipeline.py b/modelscope/pipelines/nlp/sentiment_classification_pipeline.py index db665eec..2afe64d9 100644 --- a/modelscope/pipelines/nlp/sentiment_classification_pipeline.py +++ b/modelscope/pipelines/nlp/sentiment_classification_pipeline.py @@ -74,5 +74,4 @@ class SentimentClassificationPipeline(Pipeline): probs = probs[cls_ids].tolist() cls_names = [self.model.id2label[cid] for cid in cls_ids] - return {OutputKeys.SCORES: probs, OutputKeys.LABELS: cls_names} diff --git a/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py b/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py index 617809f1..a7ea1e9a 100644 --- a/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py +++ b/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py @@ -29,7 +29,6 @@ class ZeroShotClassificationPipeline(Pipeline): preprocessor: ZeroShotClassificationPreprocessor = None, **kwargs): """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction - Args: model (SbertForSentimentClassification): a model instance preprocessor (SentimentClassificationPreprocessor): a preprocessor instance @@ -39,10 +38,8 @@ class ZeroShotClassificationPipeline(Pipeline): model = model if isinstance( model, SbertForZeroShotClassification) else Model.from_pretrained(model) - self.entailment_id = 0 self.contradiction_id = 2 - if preprocessor is None: preprocessor = ZeroShotClassificationPreprocessor(model.model_dir) model.eval() @@ -51,7 +48,6 @@ class ZeroShotClassificationPipeline(Pipeline): def _sanitize_parameters(self, **kwargs): preprocess_params = {} postprocess_params = {} - if 'candidate_labels' in kwargs: candidate_labels = kwargs.pop('candidate_labels') preprocess_params['candidate_labels'] = candidate_labels @@ -60,7 +56,6 @@ class ZeroShotClassificationPipeline(Pipeline): raise ValueError('You must include at least one label.') preprocess_params['hypothesis_template'] = kwargs.pop( 'hypothesis_template', '{}') - postprocess_params['multi_label'] = kwargs.pop('multi_label', False) return preprocess_params, {}, postprocess_params @@ -74,14 +69,11 @@ class ZeroShotClassificationPipeline(Pipeline): candidate_labels, multi_label=False) -> Dict[str, Any]: """process the prediction results - Args: inputs (Dict[str, Any]): _description_ - Returns: Dict[str, Any]: the prediction results """ - logits = inputs['logits'] if multi_label or len(candidate_labels) == 1: logits = logits[..., [self.contradiction_id, self.entailment_id]] @@ -89,7 +81,6 @@ class ZeroShotClassificationPipeline(Pipeline): else: logits = logits[..., self.entailment_id] 
scores = softmax(logits, axis=-1) - reversed_index = list(reversed(scores.argsort())) result = { OutputKeys.LABELS: [candidate_labels[i] for i in reversed_index], diff --git a/modelscope/pipelines/outputs.py b/modelscope/pipelines/outputs.py index b418fe7f..368586df 100644 --- a/modelscope/pipelines/outputs.py +++ b/modelscope/pipelines/outputs.py @@ -21,6 +21,7 @@ class OutputKeys(object): TRANSLATION = 'translation' RESPONSE = 'response' PREDICTION = 'prediction' + DIALOG_STATES = 'dialog_states' VIDEO_EMBEDDING = 'video_embedding' @@ -158,6 +159,7 @@ TASK_OUTPUTS = { # } Tasks.nli: [OutputKeys.SCORES, OutputKeys.LABELS], + # dialog intent prediction result for single sample # {'pred': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05, # 1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04, # 6.45841064e-05, 1.15997791e-04, 5.11605394e-05, 9.87020373e-01, @@ -181,9 +183,47 @@ TASK_OUTPUTS = { Tasks.dialog_intent_prediction: [OutputKeys.PREDICTION, OutputKeys.LABEL_POS, OutputKeys.LABEL], + # dialog modeling prediction result for single sample # sys : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!'] Tasks.dialog_modeling: [OutputKeys.RESPONSE], + # dialog state tracking result for single sample + # { + # "dialog_states": { + # "taxi-leaveAt": "none", + # "taxi-destination": "none", + # "taxi-departure": "none", + # "taxi-arriveBy": "none", + # "restaurant-book_people": "none", + # "restaurant-book_day": "none", + # "restaurant-book_time": "none", + # "restaurant-food": "none", + # "restaurant-pricerange": "none", + # "restaurant-name": "none", + # "restaurant-area": "none", + # "hotel-book_people": "none", + # "hotel-book_day": "none", + # "hotel-book_stay": "none", + # "hotel-name": "none", + # "hotel-area": "none", + # "hotel-parking": "none", + # "hotel-pricerange": "cheap", + # "hotel-stars": "none", + # "hotel-internet": "none", + # "hotel-type": "true", + # "attraction-type": "none", + # "attraction-name": "none", + # "attraction-area": "none", + # "train-book_people": "none", + # "train-leaveAt": "none", + # "train-destination": "none", + # "train-day": "none", + # "train-arriveBy": "none", + # "train-departure": "none" + # } + # } + Tasks.dialog_state_tracking: [OutputKeys.DIALOG_STATES], + # ============ audio tasks =================== # audio processed for single file in PCM format diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index 3bdf0dcf..962e9f6e 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -13,6 +13,7 @@ try: from .nlp import * # noqa F403 from .space.dialog_intent_prediction_preprocessor import * # noqa F403 from .space.dialog_modeling_preprocessor import * # noqa F403 + from .space.dialog_state_tracking_preprocessor import * # noqa F403 except ModuleNotFoundError as e: if str(e) == "No module named 'tensorflow'": pass diff --git a/modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py b/modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py new file mode 100644 index 00000000..6ddb9a9c --- /dev/null +++ b/modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py @@ -0,0 +1,133 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import os +from typing import Any, Dict + +from modelscope.utils.constant import Fields +from modelscope.utils.type_assert import type_assert +from ...metainfo import Preprocessors +from ..base import Preprocessor +from ..builder import PREPROCESSORS +from .dst_processors import convert_examples_to_features, multiwoz22Processor + +__all__ = ['DialogStateTrackingPreprocessor'] + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.dialog_state_tracking_preprocessor) +class DialogStateTrackingPreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + """preprocess the data via the config and tokenizer from the `model_dir` path + + Args: + model_dir (str): model path + """ + super().__init__(*args, **kwargs) + + from sofa.models.space import SpaceTokenizer, SpaceConfig + self.model_dir: str = model_dir + self.config = SpaceConfig.from_pretrained(self.model_dir) + self.tokenizer = SpaceTokenizer.from_pretrained(self.model_dir) + self.processor = multiwoz22Processor() + + @type_assert(object, dict) + def __call__(self, data: Dict) -> Dict[str, Any]: + """process the raw input data + + Args: + data (Dict): the dialog inputs, including user/system utterances, dialog acts and history states + Example: + {'utter': {...}, 'history_states': [...]} + + Returns: + Dict[str, Any]: the preprocessed data + """ + import torch + from torch.utils.data import (DataLoader, RandomSampler, + SequentialSampler) + + utter = data['utter'] + history_states = data['history_states'] + example = self.processor.create_example( + inputs=utter, + history_states=history_states, + set_type='test', + slot_list=self.config.dst_slot_list, + label_maps={}, + append_history=True, + use_history_labels=True, + swap_utterances=True, + label_value_repetitions=True, + delexicalize_sys_utts=True, + unk_token='[UNK]', + analyze=False) + + features = convert_examples_to_features( + examples=[example], + slot_list=self.config.dst_slot_list, + class_types=self.config.dst_class_types, + model_type=self.config.model_type, + tokenizer=self.tokenizer, + max_seq_length=180, # args.max_seq_length + slot_value_dropout=0.0) + + all_input_ids = torch.tensor([f.input_ids for f in features], + dtype=torch.long) + all_input_mask = torch.tensor([f.input_mask for f in features], + dtype=torch.long) + all_segment_ids = torch.tensor([f.segment_ids for f in features], + dtype=torch.long) + all_example_index = torch.arange( + all_input_ids.size(0), dtype=torch.long) + f_start_pos = [f.start_pos for f in features] + f_end_pos = [f.end_pos for f in features] + f_inform_slot_ids = [f.inform_slot for f in features] + f_refer_ids = [f.refer_id for f in features] + f_diag_state = [f.diag_state for f in features] + f_class_label_ids = [f.class_label_id for f in features] + all_start_positions = {} + all_end_positions = {} + all_inform_slot_ids = {} + all_refer_ids = {} + all_diag_state = {} + all_class_label_ids = {} + for s in self.config.dst_slot_list: + all_start_positions[s] = torch.tensor([f[s] for f in f_start_pos], + dtype=torch.long) + all_end_positions[s] = torch.tensor([f[s] for f in f_end_pos], + dtype=torch.long) + all_inform_slot_ids[s] = torch.tensor( + [f[s] for f in f_inform_slot_ids], dtype=torch.long) + all_refer_ids[s] = torch.tensor([f[s] for f in f_refer_ids], + dtype=torch.long) + all_diag_state[s] = torch.tensor([f[s] for f in f_diag_state], + dtype=torch.long) + all_class_label_ids[s] = torch.tensor( + [f[s] for f in f_class_label_ids], dtype=torch.long) + dataset = [ + all_input_ids, all_input_mask, all_segment_ids, + all_start_positions, all_end_positions, all_inform_slot_ids, +
all_refer_ids, all_diag_state, all_class_label_ids, + all_example_index + ] + + with torch.no_grad(): + diag_state = { + slot: + torch.tensor([0 for _ in range(self.config.eval_batch_size) + ]).to(self.config.device) + for slot in self.config.dst_slot_list + } + + if len(history_states) > 2: + ds = history_states[-2] + else: + ds = {slot: 'none' for slot in self.config.dst_slot_list} + + return { + 'batch': dataset, + 'features': features, + 'diag_state': diag_state, + 'ds': ds + } diff --git a/modelscope/preprocessors/space/dst_processors.py b/modelscope/preprocessors/space/dst_processors.py new file mode 100644 index 00000000..1f9920a9 --- /dev/null +++ b/modelscope/preprocessors/space/dst_processors.py @@ -0,0 +1,1441 @@ +# +# Copyright 2020 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# Part of this code is based on the source code of Transformers +# (arXiv:1910.03771) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +import json +import numpy as np +import six +from tqdm import tqdm + +logger = logging.getLogger(__name__) +USER_NAME = 'User' +SYSTEM_NAME = 'System' +DIALOG_ACT = 'Dialog_Act' + + +class DSTProcessor(object): + ACTS_DICT = { + 'taxi-depart': 'taxi-departure', + 'taxi-dest': 'taxi-destination', + 'taxi-leaveat': 'taxi-leaveAt', + 'taxi-arriveby': 'taxi-arriveBy', + 'train-depart': 'train-departure', + 'train-dest': 'train-destination', + 'train-leaveat': 'train-leaveAt', + 'train-arriveby': 'train-arriveBy', + 'train-bookpeople': 'train-book_people', + 'restaurant-price': 'restaurant-pricerange', + 'restaurant-bookpeople': 'restaurant-book_people', + 'restaurant-bookday': 'restaurant-book_day', + 'restaurant-booktime': 'restaurant-book_time', + 'hotel-price': 'hotel-pricerange', + 'hotel-bookpeople': 'hotel-book_people', + 'hotel-bookday': 'hotel-book_day', + 'hotel-bookstay': 'hotel-book_stay', + 'booking-bookpeople': 'booking-book_people', + 'booking-bookday': 'booking-book_day', + 'booking-bookstay': 'booking-book_stay', + 'booking-booktime': 'booking-book_time', + } + + LABEL_MAPS = {} # Loaded from file + + def __init__(self): + # Required for mapping slot names in dialogue_acts.json file + # to proper designations. 
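+ # (see ACTS_DICT above, e.g. 'taxi-dest' -> 'taxi-destination').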
+ pass + + def _convert_inputs_to_utterances(self, inputs: dict, + history_states: list): + """Generate the utterances with user, sys, dialog_acts and metadata, + where the metadata comes from the history_states or from previous inference pipeline outputs""" + + utterances = [] + user_inputs = [] + sys_gen_inputs = [] + dialog_acts_inputs = [] + for i, item in enumerate(inputs): + name, turn = item.split('-') + if name == USER_NAME: + user_inputs.insert(int(turn) - 1, inputs[item]) + elif name == SYSTEM_NAME: + sys_gen_inputs.insert(int(turn) - 1, inputs[item]) + else: + dialog_acts_inputs.insert(int(turn) - 1, inputs[item]) + + # the user leads the dialog, so there is always one more user input than system inputs and dialog acts + assert len(user_inputs) - 1 == len(sys_gen_inputs) + assert len(user_inputs) - 1 == len(dialog_acts_inputs) + # the history states record both user and sys states + assert len(history_states) == len(user_inputs) + len(sys_gen_inputs) + + # interleave user and system turns with their history states + for i, item in enumerate(history_states): + utterance = {} + # the dialog_act at user turn is useless + utterance['dialog_act'] = dialog_acts_inputs[ + i // 2] if i % 2 == 1 else {} + utterance['text'] = sys_gen_inputs[ + i // 2] if i % 2 == 1 else user_inputs[i // 2] + utterance['metadata'] = item + utterance['span_info'] = [] + utterances.append(utterance) + + return utterances + + def _load_acts(self, inputs: dict, dialog_id='example.json'): + dialog_acts_inputs = [] + for i, item in enumerate(inputs): + name, turn = item.split('-') + if name == DIALOG_ACT: + dialog_acts_inputs.insert(int(turn) - 1, inputs[item]) + s_dict = {} + + for j, item in enumerate(dialog_acts_inputs): + if isinstance(item, dict): + for a in item: + aa = a.lower().split('-') + if aa[1] == 'inform' or aa[1] == 'recommend' or \ + aa[1] == 'select' or aa[1] == 'book': + for i in item[a]: + s = i[0].lower() + v = i[1].lower().strip() + if s == 'none' or v == '?' or v == 'none': + continue + slot = aa[0] + '-' + s + if slot in self.ACTS_DICT: + slot = self.ACTS_DICT[slot] + key = dialog_id, str(int(j) + 1), slot + # In case of multiple mentioned values... + # ... Option 1: Keep first informed value + if key not in s_dict: + s_dict[key] = list([v]) + # ... Option 2: Keep last informed value + # s_dict[key] = list([v]) + + return s_dict + + +class multiwoz22Processor(DSTProcessor): + + def __init__(self): + super().__init__() + + def normalize_time(self, text): + text = re.sub(r'(\d{1})(a\.?m\.?|p\.?m\.?)', r'\1 \2', + text) # am/pm without space + text = re.sub(r'(^| )(\d{1,2}) (a\.?m\.?|p\.?m\.?)', r'\1\2:00 \3', + text) # am/pm short to long form + text = re.sub( + r'(^| )(at|from|by|until|after) ?(\d{1,2}) ?(\d{2})([^0-9]|$)', + r'\1\2 \3:\4\5', text) # Missing separator + text = re.sub(r'(^| )(\d{2})[;.,](\d{2})', r'\1\2:\3', + text) # Wrong separator + text = re.sub(r'(^| )(at|from|by|until|after) ?(\d{1,2})([;., ]|$)', + r'\1\2 \3:00\4', text) # normalize simple full hour time + text = re.sub(r'(^| )(\d{1}:\d{2})', r'\g<1>0\2', + text) # Add missing leading 0 + # Map 12 hour times to 24 hour times + text = \ + re.sub( + r'(\d{2})(:\d{2}) ?p\.?m\.?', + lambda x: str(int(x.groups()[0]) + 12 + if int(x.groups()[0]) < 12 else int(x.groups()[0])) + x.groups()[1], text) + text = re.sub(r'(^| )24:(\d{2})', r'\g<1>00:\2', + text) # Correct times that use 24 as hour + return text + + def normalize_text(self, text): + text = self.normalize_time(text) + text = re.sub("n't", ' not', text) + text = re.sub('(^| )zero(-| )star([s.,? ]|$)', r'\g<1>0 star\3', text) + text = re.sub('(^| )one(-| )star([s.,? ]|$)', r'\g<1>1 star\3', text) + text = re.sub('(^| )two(-| )star([s.,? ]|$)', r'\g<1>2 star\3', text) + text = re.sub('(^| )three(-| )star([s.,? ]|$)', r'\g<1>3 star\3', text) + text = re.sub('(^| )four(-| )star([s.,? ]|$)', r'\g<1>4 star\3', text) + text = re.sub('(^| )five(-| )star([s.,? ]|$)', r'\g<1>5 star\3', text) + text = re.sub('archaelogy', 'archaeology', text) # Systematic typo + text = re.sub('guesthouse', 'guest house', text) # Normalization + text = re.sub('(^| )b ?& ?b([.,? ]|$)', r'\1bed and breakfast\2', + text) # Normalization + text = re.sub('bed & breakfast', 'bed and breakfast', + text) # Normalization + return text + + # Loads the dialogue_acts.json and returns a dict + # of slot-value pairs. + def load_acts(self, input_file): + with open(input_file) as f: + acts = json.load(f) + s_dict = {} + for d in acts: + for t in acts[d]: + if int(t) % 2 == 0: + continue + # Only process, if turn has annotation + if isinstance(acts[d][t]['dialog_act'], dict): + for a in acts[d][t]['dialog_act']: + aa = a.lower().split('-') + if aa[1] == 'inform' or aa[1] == 'recommend' \ + or aa[1] == 'select' or aa[1] == 'book': + for i in acts[d][t]['dialog_act'][a]: + s = i[0].lower() + v = i[1].lower().strip() + if s == 'none' or v == '?' or v == 'none': + continue + slot = aa[0] + '-' + s + if slot in self.ACTS_DICT: + slot = self.ACTS_DICT[slot] + key = d, str(int(t) // 2 + 1), slot + # In case of multiple mentioned values... + # ... Option 1: Keep first informed value + if key not in s_dict: + s_dict[key] = list([v]) + # ... Option 2: Keep last informed value + # s_dict[key] = list([v]) + return s_dict + + # This should only contain label normalizations. All other mappings should + # be defined in LABEL_MAPS.
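+ # e.g. '' or 'not mentioned' -> 'none'; for the boolean slots hotel-parking and + # hotel-internet, 'yes'/'free' -> 'true' and 'no' -> 'false'.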
+ def normalize_label(self, slot, value_label): + # Normalization of empty slots + if value_label == '' or value_label == 'not mentioned': + return 'none' + + # Normalization of time slots + if 'leaveAt' in slot or 'arriveBy' in slot or slot == 'restaurant-book_time': + return self.normalize_time(value_label) + + # Normalization + if 'type' in slot or 'name' in slot or 'destination' in slot or 'departure' in slot: + value_label = re.sub('guesthouse', 'guest house', value_label) + + # Map to boolean slots + if slot == 'hotel-parking' or slot == 'hotel-internet': + if value_label == 'yes' or value_label == 'free': + return 'true' + if value_label == 'no': + return 'false' + if slot == 'hotel-type': + if value_label == 'hotel': + return 'true' + if value_label == 'guest house': + return 'false' + + return value_label + + def tokenize(self, utt): + utt_lower = convert_to_unicode(utt).lower() + utt_lower = self.normalize_text(utt_lower) + utt_tok = [ + tok for tok in map(str.strip, re.split(r'(\W+)', utt_lower)) + if len(tok) > 0 + ] + return utt_tok + + def delex_utt(self, utt, values, unk_token='[UNK]'): + utt_norm = self.tokenize(utt) + for s, vals in values.items(): + for v in vals: + if v != 'none': + v_norm = self.tokenize(v) + v_len = len(v_norm) + for i in range(len(utt_norm) + 1 - v_len): + if utt_norm[i:i + v_len] == v_norm: + utt_norm[i:i + v_len] = [unk_token] * v_len + return utt_norm + + def get_token_pos(self, tok_list, value_label): + find_pos = [] + found = False + label_list = [ + item for item in map(str.strip, re.split(r'(\W+)', value_label)) + if len(item) > 0 + ] + len_label = len(label_list) + for i in range(len(tok_list) + 1 - len_label): + if tok_list[i:i + len_label] == label_list: + find_pos.append((i, i + len_label)) # start, exclusive_end + found = True + return found, find_pos + + def check_label_existence(self, value_label, usr_utt_tok): + in_usr, usr_pos = self.get_token_pos(usr_utt_tok, value_label) + # If no hit even though there should be one, check for value label variants + if not in_usr and value_label in self.LABEL_MAPS: + for value_label_variant in self.LABEL_MAPS[value_label]: + in_usr, usr_pos = self.get_token_pos(usr_utt_tok, + value_label_variant) + if in_usr: + break + return in_usr, usr_pos + + def check_slot_referral(self, value_label, slot, seen_slots): + referred_slot = 'none' + if slot == 'hotel-stars' or slot == 'hotel-internet' or slot == 'hotel-parking': + return referred_slot + for s in seen_slots: + # Avoid matches for slots that share values with different meaning. + # hotel-internet and -parking are handled separately as Boolean slots. 
+ if s == 'hotel-stars' or s == 'hotel-internet' or s == 'hotel-parking': + continue + if re.match('(hotel|restaurant)-book_people', + s) and slot == 'hotel-book_stay': + continue + if re.match('(hotel|restaurant)-book_people', + slot) and s == 'hotel-book_stay': + continue + if slot != s and (slot not in seen_slots + or seen_slots[slot] != value_label): + if seen_slots[s] == value_label: + referred_slot = s + break + elif value_label in self.LABEL_MAPS: + for value_label_variant in self.LABEL_MAPS[value_label]: + if seen_slots[s] == value_label_variant: + referred_slot = s + break + return referred_slot + + def is_in_list(self, tok, value): + found = False + tok_list = [ + item for item in map(str.strip, re.split(r'(\W+)', tok)) + if len(item) > 0 + ] + value_list = [ + item for item in map(str.strip, re.split(r'(\W+)', value)) + if len(item) > 0 + ] + tok_len = len(tok_list) + value_len = len(value_list) + for i in range(tok_len + 1 - value_len): + if tok_list[i:i + value_len] == value_list: + found = True + break + return found + + # Fuzzy matching to label informed slot values + def check_slot_inform(self, value_label, inform_label): + result = False + informed_value = 'none' + vl = ' '.join(self.tokenize(value_label)) + for il in inform_label: + if vl == il: + result = True + elif self.is_in_list(il, vl): + result = True + elif self.is_in_list(vl, il): + result = True + elif il in self.LABEL_MAPS: + for il_variant in self.LABEL_MAPS[il]: + if vl == il_variant: + result = True + break + elif self.is_in_list(il_variant, vl): + result = True + break + elif self.is_in_list(vl, il_variant): + result = True + break + elif vl in self.LABEL_MAPS: + for value_label_variant in self.LABEL_MAPS[vl]: + if value_label_variant == il: + result = True + break + elif self.is_in_list(il, value_label_variant): + result = True + break + elif self.is_in_list(value_label_variant, il): + result = True + break + if result: + informed_value = il + break + return result, informed_value + + def get_turn_label(self, value_label, inform_label, sys_utt_tok, + usr_utt_tok, slot, seen_slots, slot_last_occurrence): + usr_utt_tok_label = [0 for _ in usr_utt_tok] + informed_value = 'none' + referred_slot = 'none' + if value_label == 'none' or value_label == 'dontcare' or value_label == 'true' or value_label == 'false': + class_type = value_label + else: + in_usr, usr_pos = self.check_label_existence( + value_label, usr_utt_tok) + is_informed, informed_value = self.check_slot_inform( + value_label, inform_label) + if in_usr: + class_type = 'copy_value' + if slot_last_occurrence: + (s, e) = usr_pos[-1] + for i in range(s, e): + usr_utt_tok_label[i] = 1 + else: + for (s, e) in usr_pos: + for i in range(s, e): + usr_utt_tok_label[i] = 1 + elif is_informed: + class_type = 'inform' + else: + referred_slot = self.check_slot_referral( + value_label, slot, seen_slots) + if referred_slot != 'none': + class_type = 'refer' + else: + class_type = 'unpointable' + return informed_value, referred_slot, usr_utt_tok_label, class_type + + def _create_example(self, + utterances, + sys_inform_dict, + set_type, + slot_list, + label_maps={}, + append_history=False, + use_history_labels=False, + swap_utterances=False, + label_value_repetitions=False, + delexicalize_sys_utts=False, + unk_token='[UNK]', + analyze=False, + dialog_id='example.json'): + + # Collects all slot changes throughout the dialog + # cumulative_labels = {slot: 'none' for slot in slot_list} + + # First system utterance is empty, since multiwoz starts with user input + 
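# seeding the token list with an empty system turn keeps (sys, usr) pairs + # aligned when iterating over the turns in steps of two below. +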
utt_tok_list = [[]] + mod_slots_list = [] + + # Collect all utterances and their metadata + usr_sys_switch = True + turn_itr = 0 + + inform_dict = {slot: 'none' for slot in slot_list} + for utt in utterances: + # Assert that system and user utterances alternate + is_sys_utt = utt['metadata'] != {} + if usr_sys_switch == is_sys_utt: + print( + 'WARN: Wrong order of system and user utterances. Skipping rest of the dialog %s' + % (dialog_id)) + break + usr_sys_switch = is_sys_utt + + if is_sys_utt: + turn_itr += 1 + + # Delexicalize sys utterance + if delexicalize_sys_utts and is_sys_utt: + inform_dict = {slot: 'none' for slot in slot_list} + for slot in slot_list: + if (str(dialog_id), str(turn_itr), + slot) in sys_inform_dict: + inform_dict[slot] = sys_inform_dict[(str(dialog_id), + str(turn_itr), + slot)] + utt_tok_list.append( + self.delex_utt(utt['text'], inform_dict, + unk_token)) # normalize utterances + else: + utt_tok_list.append(self.tokenize( + utt['text'])) # normalize utterances + + # Form proper (usr, sys) turns + turn_itr = 0 + diag_seen_slots_dict = {} + diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list} + diag_state = {slot: 'none' for slot in slot_list} + sys_utt_tok = [] + usr_utt_tok = [] + hst_utt_tok = [] + hst_utt_tok_label_dict = {slot: [] for slot in slot_list} + new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy() + new_diag_state = diag_state.copy() + + ###### + mod_slots_list = [] + ##### + + for i in range(0, len(utt_tok_list) - 1, 2): + sys_utt_tok_label_dict = {} + usr_utt_tok_label_dict = {} + value_dict = {} + # inform_dict = {} + inform_slot_dict = {} + referral_dict = {} + class_type_dict = {} + + # Collect turn data + if append_history: + if swap_utterances: + hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok + else: + hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok + sys_utt_tok = utt_tok_list[i] + usr_utt_tok = utt_tok_list[i + 1] + turn_slots = mod_slots_list[ + i + 1] if len(mod_slots_list) > 1 else {} + + guid = '%s-%s-%s' % (set_type, str(dialog_id), str(turn_itr)) + + if analyze: + print('%15s %2s %s ||| %s' % + (dialog_id, turn_itr, ' '.join(sys_utt_tok), + ' '.join(usr_utt_tok))) + print('%15s %2s [' % (dialog_id, turn_itr), end='') + + new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy() + new_diag_state = diag_state.copy() + for slot in slot_list: + value_label = 'none' + if slot in turn_slots: + value_label = turn_slots[slot] + # We keep the original labels so as to not + # overlook unpointable values, as well as to not + # modify any of the original labels for test sets, + # since this would make comparison difficult. 
+ value_dict[slot] = value_label + elif label_value_repetitions and slot in diag_seen_slots_dict: + value_label = diag_seen_slots_value_dict[slot] + + # Get dialog act annotations + inform_label = list(['none']) + inform_slot_dict[slot] = 0 + if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict: + inform_label = list([ + self.normalize_label(slot, i) + for i in sys_inform_dict[(str(dialog_id), + str(turn_itr), slot)] + ]) + inform_slot_dict[slot] = 1 + elif (str(dialog_id), str(turn_itr), + 'booking-' + slot.split('-')[1]) in sys_inform_dict: + inform_label = list([ + self.normalize_label(slot, i) + for i in sys_inform_dict[(str(dialog_id), + str(turn_itr), 'booking-' + + slot.split('-')[1])] + ]) + inform_slot_dict[slot] = 1 + + (informed_value, referred_slot, usr_utt_tok_label, + class_type) = self.get_turn_label( + value_label, + inform_label, + sys_utt_tok, + usr_utt_tok, + slot, + diag_seen_slots_value_dict, + slot_last_occurrence=True) + + # inform_dict[slot] = informed_value + + # Generally don't use span prediction on sys utterance (but inform prediction instead). + sys_utt_tok_label = [0 for _ in sys_utt_tok] + + # Determine what to do with value repetitions. + # If value is unique in seen slots, then tag it, otherwise not, + # since correct slot assignment cannot be guaranteed anymore. + if label_value_repetitions and slot in diag_seen_slots_dict: + if class_type == 'copy_value' and list( + diag_seen_slots_value_dict.values()).count( + value_label) > 1: + class_type = 'none' + usr_utt_tok_label = [0 for _ in usr_utt_tok_label] + + sys_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label_dict[slot] = usr_utt_tok_label + + if append_history: + if use_history_labels: + if swap_utterances: + new_hst_utt_tok_label_dict[ + slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[ + slot] + else: + new_hst_utt_tok_label_dict[ + slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[ + slot] + else: + new_hst_utt_tok_label_dict[slot] = [ + 0 for _ in sys_utt_tok_label + usr_utt_tok_label + + new_hst_utt_tok_label_dict[slot] + ] + + # For now, we map all occurrences of unpointable slot values + # to none. However, since the labels will still suggest + # a presence of unpointable slot values, the task of the + # DST is still to find those values. It is just not + # possible to do that via span prediction on the current input. + if class_type == 'unpointable': + class_type_dict[slot] = 'none' + referral_dict[slot] = 'none' + if analyze: + if slot not in diag_seen_slots_dict or value_label != diag_seen_slots_value_dict[ + slot]: + print('(%s): %s, ' % (slot, value_label), end='') + elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] \ + and class_type != 'copy_value' and class_type != 'inform': + # If the slot has been seen before and its class type did not change, label this slot as not present, + # assuming that the slot has not actually been mentioned in this turn. + # Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform, + # this must mean there is evidence in the original labels, therefore consider + # them as mentioned again. + class_type_dict[slot] = 'none' + referral_dict[slot] = 'none' + else: + class_type_dict[slot] = class_type + referral_dict[slot] = referred_slot + # Remember that this slot was mentioned during this dialog already.
+ if class_type != 'none': + diag_seen_slots_dict[slot] = class_type + diag_seen_slots_value_dict[slot] = value_label + new_diag_state[slot] = class_type + # Unpointable is not a valid class, therefore replace with + # some valid class for now... + if class_type == 'unpointable': + new_diag_state[slot] = 'copy_value' + + if analyze: + print(']') + + if swap_utterances: + txt_a = usr_utt_tok + txt_b = sys_utt_tok + txt_a_lbl = usr_utt_tok_label_dict + txt_b_lbl = sys_utt_tok_label_dict + else: + txt_a = sys_utt_tok + txt_b = usr_utt_tok + txt_a_lbl = sys_utt_tok_label_dict + txt_b_lbl = usr_utt_tok_label_dict + """ + text_a: dialog text + text_b: dialog text + history: dialog text + text_a_label: label,ignore during inference,turns to start/end pos + text_b_label: label,ignore during inference,turns to start/end pos + history_label: label,ignore during inference,turns to start/end pos + values: ignore during inference + inform_label: ignore during inference + inform_slot_label: input, system dialog action + refer_label: label,ignore during inference,turns to start/end pos refer_id + diag_state: input, history dialog state + class_label: label,ignore during inference,turns to start/end pos class_label_id + """ + example = DSTExample( + guid=guid, + text_a=txt_a, + text_b=txt_b, + history=hst_utt_tok, + text_a_label=txt_a_lbl, + text_b_label=txt_b_lbl, + history_label=hst_utt_tok_label_dict, + values=diag_seen_slots_value_dict.copy(), + inform_label=inform_dict, + inform_slot_label=inform_slot_dict, + refer_label=referral_dict, + diag_state=diag_state, + class_label=class_type_dict) + # Update some variables. + hst_utt_tok_label_dict = new_hst_utt_tok_label_dict.copy() + diag_state = new_diag_state.copy() + + turn_itr += 1 + return example + + def create_example(self, + inputs, + history_states, + set_type, + slot_list, + label_maps={}, + append_history=False, + use_history_labels=False, + swap_utterances=False, + label_value_repetitions=False, + delexicalize_sys_utts=False, + unk_token='[UNK]', + analyze=False, + dialog_id='0'): + utterances = self._convert_inputs_to_utterances(inputs, history_states) + sys_inform_dict = self._load_acts(inputs) + self.LABEL_MAPS = label_maps + example = self._create_example(utterances, sys_inform_dict, set_type, + slot_list, label_maps, append_history, + use_history_labels, swap_utterances, + label_value_repetitions, + delexicalize_sys_utts, unk_token, + analyze) + + return example + + def create_examples(self, + input_file, + acts_file, + set_type, + slot_list, + label_maps={}, + append_history=False, + use_history_labels=False, + swap_utterances=False, + label_value_repetitions=False, + delexicalize_sys_utts=False, + unk_token='[UNK]', + analyze=False): + """Read a DST json file into a list of DSTExample.""" + + sys_inform_dict = self.load_acts(acts_file) + + with open(input_file, 'r', encoding='utf-8') as reader: + input_data = json.load(reader) + + self.LABEL_MAPS = label_maps + + examples = [] + for dialog_id in tqdm(input_data): + entry = input_data[dialog_id] + utterances = entry['log'] + + example = self._create_example( + utterances, sys_inform_dict, set_type, slot_list, label_maps, + append_history, use_history_labels, swap_utterances, + label_value_repetitions, delexicalize_sys_utts, unk_token, + analyze) + examples.append(example) + + return examples + + +class DSTExample(object): + """ + A single training/test example for the DST dataset. 
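+ Each example covers one (system, user) turn: the tokenized utterances and + dialog history, per-slot span labels, inform and referral annotations, the + carried-over dialog state and the per-slot class labels.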
+ """ + + def __init__(self, + guid, + text_a, + text_b, + history, + text_a_label=None, + text_b_label=None, + history_label=None, + values=None, + inform_label=None, + inform_slot_label=None, + refer_label=None, + diag_state=None, + class_label=None): + self.guid = guid + self.text_a = text_a + self.text_b = text_b + self.history = history + self.text_a_label = text_a_label + self.text_b_label = text_b_label + self.history_label = history_label + self.values = values + self.inform_label = inform_label + self.inform_slot_label = inform_slot_label + self.refer_label = refer_label + self.diag_state = diag_state + self.class_label = class_label + + def __str__(self): + return self.__repr__() + + def __repr__(self): + s_dict = dict() + s_dict['guid'] = self.guid + s_dict['text_a'] = self.text_a + s_dict['text_b'] = self.text_b + s_dict['history'] = self.history + if self.text_a_label: + s_dict['text_a_label'] = self.text_a_label + if self.text_b_label: + s_dict['text_b_label'] = self.text_b_label + if self.history_label: + s_dict['history_label'] = self.history_label + if self.values: + s_dict['values'] = self.values + if self.inform_label: + s_dict['inform_label'] = self.inform_label + if self.inform_slot_label: + s_dict['inform_slot_label'] = self.inform_slot_label + if self.refer_label: + s_dict['refer_label'] = self.refer_label + if self.diag_state: + s_dict['diag_state'] = self.diag_state + if self.class_label: + s_dict['class_label'] = self.class_label + + s = json.dumps(s_dict) + return s + + +class InputFeatures(object): + """A single set of features of data.""" + + def __init__(self, + input_ids, + input_ids_unmasked, + input_mask, + segment_ids, + start_pos=None, + end_pos=None, + values=None, + inform=None, + inform_slot=None, + refer_id=None, + diag_state=None, + class_label_id=None, + guid='NONE'): + self.guid = guid + self.input_ids = input_ids + self.input_ids_unmasked = input_ids_unmasked + self.input_mask = input_mask + self.segment_ids = segment_ids + self.start_pos = start_pos + self.end_pos = end_pos + self.values = values + self.inform = inform + self.inform_slot = inform_slot + self.refer_id = refer_id + self.diag_state = diag_state + self.class_label_id = class_label_id + + +def convert_examples_to_features(examples, + slot_list, + class_types, + model_type, + tokenizer, + max_seq_length, + slot_value_dropout=0.0): + """Loads a data file into a list of `InputBatch`s.""" + + if model_type == 'bert': + model_specs = { + 'MODEL_TYPE': 'bert', + 'CLS_TOKEN': '[CLS]', + 'UNK_TOKEN': '[UNK]', + 'SEP_TOKEN': '[SEP]', + 'TOKEN_CORRECTION': 4 + } + else: + logger.error('Unknown model type (%s). Aborting.' 
% (model_type)) + exit(1) + + def _tokenize_text_and_label(text, text_label_dict, slot, tokenizer, + model_specs, slot_value_dropout): + joint_text_label = [0 for _ in text_label_dict[slot] + ] # joint all slots' label + for slot_text_label in text_label_dict.values(): + for idx, label in enumerate(slot_text_label): + if label == 1: + joint_text_label[idx] = 1 + + text_label = text_label_dict[slot] + tokens = [] + tokens_unmasked = [] + token_labels = [] + for token, token_label, joint_label in zip(text, text_label, + joint_text_label): + token = convert_to_unicode(token) + sub_tokens = tokenizer.tokenize(token) # Most time intensive step + tokens_unmasked.extend(sub_tokens) + if slot_value_dropout == 0.0 or joint_label == 0: + tokens.extend(sub_tokens) + else: + rn_list = np.random.random_sample((len(sub_tokens), )) + for rn, sub_token in zip(rn_list, sub_tokens): + if rn > slot_value_dropout: + tokens.append(sub_token) + else: + tokens.append(model_specs['UNK_TOKEN']) + token_labels.extend([token_label for _ in sub_tokens]) + assert len(tokens) == len(token_labels) + assert len(tokens_unmasked) == len(token_labels) + return tokens, tokens_unmasked, token_labels + + def _truncate_seq_pair(tokens_a, tokens_b, history, max_length): + """Truncates a sequence pair in place to the maximum length. + Copied from bert/run_classifier.py + """ + # This is a simple heuristic which will always truncate the longer sequence + # one token at a time. This makes more sense than truncating an equal percent + # of tokens from each, since if one sequence is very short then each token + # that's truncated likely contains more information than a longer sequence. + while True: + total_length = len(tokens_a) + len(tokens_b) + len(history) + if total_length <= max_length: + break + if len(history) > 0: + history.pop() + elif len(tokens_a) > len(tokens_b): + tokens_a.pop() + else: + tokens_b.pop() + + def _truncate_length_and_warn(tokens_a, tokens_b, history, max_seq_length, + model_specs, guid): + # Modifies `tokens_a` and `tokens_b` in place so that the total + # length is less than the specified length. + # Account for [CLS], [SEP], [SEP], [SEP] with "- 4" (BERT) + if len(tokens_a) + len(tokens_b) + len( + history) > max_seq_length - model_specs['TOKEN_CORRECTION']: + # logger.info('Truncate Example %s. Total len=%d.' 
% + # (guid, len(tokens_a) + len(tokens_b) + len(history))) + input_text_too_long = True + else: + input_text_too_long = False + _truncate_seq_pair(tokens_a, tokens_b, history, + max_seq_length - model_specs['TOKEN_CORRECTION']) + return input_text_too_long + + def _get_token_label_ids(token_labels_a, token_labels_b, + token_labels_history, max_seq_length, + model_specs): + token_label_ids = [] + token_label_ids.append(0) # [CLS] + for token_label in token_labels_a: + token_label_ids.append(token_label) + token_label_ids.append(0) # [SEP] + for token_label in token_labels_b: + token_label_ids.append(token_label) + token_label_ids.append(0) # [SEP] + for token_label in token_labels_history: + token_label_ids.append(token_label) + token_label_ids.append(0) # [SEP] + while len(token_label_ids) < max_seq_length: + token_label_ids.append(0) # padding + assert len(token_label_ids) == max_seq_length + return token_label_ids + + def _get_start_end_pos(class_type, token_label_ids, max_seq_length): + if class_type == 'copy_value' and 1 not in token_label_ids: + class_type = 'none' + start_pos = 0 + end_pos = 0 + if 1 in token_label_ids: + start_pos = token_label_ids.index(1) + # Parsing is supposed to find only first location of wanted value + if 0 not in token_label_ids[start_pos:]: + end_pos = len(token_label_ids[start_pos:]) + start_pos - 1 + else: + end_pos = token_label_ids[start_pos:].index(0) + start_pos - 1 + for i in range(max_seq_length): + if i >= start_pos and i <= end_pos: + assert token_label_ids[i] == 1 + return class_type, start_pos, end_pos + + def _get_transformer_input(tokens_a, tokens_b, history, max_seq_length, + tokenizer, model_specs): + # The convention in BERT is: + # (a) For sequence pairs: + # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] + # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 + # (b) For single sequences: + # tokens: [CLS] the dog is hairy . [SEP] + # type_ids: 0 0 0 0 0 0 0 + # + # Where "type_ids" are used to indicate whether this is the first + # sequence or the second sequence. The embedding vectors for `type=0` and + # `type=1` were learned during pre-training and are added to the wordpiece + # embedding vector (and position vector). This is not *strictly* necessary + # since the [SEP] token unambiguously separates the sequences, but it makes + # it easier for the model to learn the concept of sequences. + # + # For classification tasks, the first vector (corresponding to [CLS]) is + # used as the "sentence vector". Note that this only makes sense because + # the entire model is fine-tuned. + tokens = [] + segment_ids = [] + tokens.append(model_specs['CLS_TOKEN']) + segment_ids.append(0) + for token in tokens_a: + tokens.append(token) + segment_ids.append(0) + tokens.append(model_specs['SEP_TOKEN']) + segment_ids.append(0) + for token in tokens_b: + tokens.append(token) + segment_ids.append(1) + tokens.append(model_specs['SEP_TOKEN']) + segment_ids.append(1) + for token in history: + tokens.append(token) + segment_ids.append(1) + tokens.append(model_specs['SEP_TOKEN']) + segment_ids.append(1) + input_ids = tokenizer.convert_tokens_to_ids(tokens) + # The mask has 1 for real tokens and 0 for padding tokens. Only real + # tokens are attended to. + input_mask = [1] * len(input_ids) + # Zero-pad up to the sequence length. 
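+        # (Token id 0 is [PAD] in standard BERT vocabularies; padded
+        #  positions also get segment id 0, so they carry no signal.)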
+ while len(input_ids) < max_seq_length: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + assert len(input_ids) == max_seq_length + assert len(input_mask) == max_seq_length + assert len(segment_ids) == max_seq_length + return tokens, input_ids, input_mask, segment_ids + + total_cnt = 0 + too_long_cnt = 0 + + refer_list = ['none'] + slot_list + + features = [] + # Convert single example + for (example_index, example) in enumerate(examples): + + total_cnt += 1 + + value_dict = {} + inform_dict = {} + inform_slot_dict = {} + refer_id_dict = {} + diag_state_dict = {} + class_label_id_dict = {} + start_pos_dict = {} + end_pos_dict = {} + for slot in slot_list: + tokens_a, tokens_a_unmasked, token_labels_a = _tokenize_text_and_label( + example.text_a, example.text_a_label, slot, tokenizer, + model_specs, slot_value_dropout) + tokens_b, tokens_b_unmasked, token_labels_b = _tokenize_text_and_label( + example.text_b, example.text_b_label, slot, tokenizer, + model_specs, slot_value_dropout) + tokens_history, tokens_history_unmasked, token_labels_history = _tokenize_text_and_label( + example.history, example.history_label, slot, tokenizer, + model_specs, slot_value_dropout) + + input_text_too_long = _truncate_length_and_warn( + tokens_a, tokens_b, tokens_history, max_seq_length, + model_specs, example.guid) + + if input_text_too_long: + + token_labels_a = token_labels_a[:len(tokens_a)] + token_labels_b = token_labels_b[:len(tokens_b)] + token_labels_history = token_labels_history[:len(tokens_history + )] + tokens_a_unmasked = tokens_a_unmasked[:len(tokens_a)] + tokens_b_unmasked = tokens_b_unmasked[:len(tokens_b)] + tokens_history_unmasked = tokens_history_unmasked[:len( + tokens_history)] + + assert len(token_labels_a) == len(tokens_a) + assert len(token_labels_b) == len(tokens_b) + assert len(token_labels_history) == len(tokens_history) + assert len(token_labels_a) == len(tokens_a_unmasked) + assert len(token_labels_b) == len(tokens_b_unmasked) + assert len(token_labels_history) == len(tokens_history_unmasked) + token_label_ids = _get_token_label_ids(token_labels_a, + token_labels_b, + token_labels_history, + max_seq_length, model_specs) + + value_dict[slot] = example.values[slot] + inform_dict[slot] = example.inform_label[slot] + + class_label_mod, start_pos_dict[slot], end_pos_dict[ + slot] = _get_start_end_pos(example.class_label[slot], + token_label_ids, max_seq_length) + if class_label_mod != example.class_label[slot]: + example.class_label[slot] = class_label_mod + inform_slot_dict[slot] = example.inform_slot_label[slot] + refer_id_dict[slot] = refer_list.index(example.refer_label[slot]) + diag_state_dict[slot] = class_types.index(example.diag_state[slot]) + class_label_id_dict[slot] = class_types.index( + example.class_label[slot]) + + if input_text_too_long: + too_long_cnt += 1 + + tokens, input_ids, input_mask, segment_ids = _get_transformer_input( + tokens_a, tokens_b, tokens_history, max_seq_length, tokenizer, + model_specs) + if slot_value_dropout > 0.0: + _, input_ids_unmasked, _, _ = _get_transformer_input( + tokens_a_unmasked, tokens_b_unmasked, tokens_history_unmasked, + max_seq_length, tokenizer, model_specs) + else: + input_ids_unmasked = input_ids + + assert (len(input_ids) == len(input_ids_unmasked)) + + features.append( + InputFeatures( + guid=example.guid, + input_ids=input_ids, + input_ids_unmasked=input_ids_unmasked, + input_mask=input_mask, + segment_ids=segment_ids, + start_pos=start_pos_dict, + end_pos=end_pos_dict, + values=value_dict, + 
inform=inform_dict,
+                inform_slot=inform_slot_dict,
+                refer_id=refer_id_dict,
+                diag_state=diag_state_dict,
+                class_label_id=class_label_id_dict))
+
+    return features
+
+
+# From bert.tokenization (TF code)
+def convert_to_unicode(text):
+    """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
+    if six.PY3:
+        if isinstance(text, str):
+            return text
+        elif isinstance(text, bytes):
+            return text.decode('utf-8', 'ignore')
+        else:
+            raise ValueError('Unsupported string type: %s' % (type(text)))
+    elif six.PY2:
+        if isinstance(text, str):
+            return text.decode('utf-8', 'ignore')
+        elif isinstance(text, unicode):
+            return text
+        else:
+            raise ValueError('Unsupported string type: %s' % (type(text)))
+    else:
+        raise ValueError('Not running on Python 2 or Python 3?')
+
+
+if __name__ == '__main__':
+    processor = multiwoz22Processor()
+    set_type = 'test'
+    slot_list = [
+        'taxi-leaveAt', 'taxi-destination', 'taxi-departure', 'taxi-arriveBy',
+        'restaurant-book_people', 'restaurant-book_day',
+        'restaurant-book_time', 'restaurant-food', 'restaurant-pricerange',
+        'restaurant-name', 'restaurant-area', 'hotel-book_people',
+        'hotel-book_day', 'hotel-book_stay', 'hotel-name', 'hotel-area',
+        'hotel-parking', 'hotel-pricerange', 'hotel-stars', 'hotel-internet',
+        'hotel-type', 'attraction-type', 'attraction-name', 'attraction-area',
+        'train-book_people', 'train-leaveAt', 'train-destination', 'train-day',
+        'train-arriveBy', 'train-departure'
+    ]
+    append_history = True
+    use_history_labels = True
+    swap_utterances = True
+    label_value_repetitions = True
+    delexicalize_sys_utts = True
+    unk_token = '[UNK]'
+    analyze = False
+
+    utter1 = {
+        'User-1':
+        'am looking for a place to to stay that has cheap price range it should be in a type of hotel'
+    }
+    history_states1 = [
+        {},
+    ]
+    utter2 = {
+        'User-1':
+        'am looking for a place to to stay that has cheap price range it should be in a type of hotel',
+        'System-1':
+        'Okay, do you have a specific area you want to stay in?',
+        'Dialog_Act-1': {
+            'Hotel-Request': [['Area', '?']]
+        },
+        'User-2':
+        'no, i just need to make sure it\'s cheap. oh, and i need parking',
+    }
+
+    history_states2 = [{}, {
+        'taxi': {
+            'book': {
+                'booked': []
+            },
+            'semi': {
+                'leaveAt': '',
+                'destination': '',
+                'departure': '',
+                'arriveBy': ''
+            }
+        },
+        'police': {
+            'book': {
+                'booked': []
+            },
+            'semi': {}
+        },
+        'restaurant': {
+            'book': {
+                'booked': [],
+                'people': '',
+                'day': '',
+                'time': ''
+            },
+            'semi': {
+                'food': '',
+                'pricerange': '',
+                'name': '',
+                'area': ''
+            }
+        },
+        'hospital': {
+            'book': {
+                'booked': []
+            },
+            'semi': {
+                'department': ''
+            }
+        },
+        'hotel': {
+            'book': {
+                'booked': [],
+                'people': '',
+                'day': '',
+                'stay': ''
+            },
+            'semi': {
+                'name': 'not mentioned',
+                'area': 'not mentioned',
+                'parking': 'not mentioned',
+                'pricerange': 'cheap',
+                'stars': 'not mentioned',
+                'internet': 'not mentioned',
+                'type': 'hotel'
+            }
+        },
+        'attraction': {
+            'book': {
+                'booked': []
+            },
+            'semi': {
+                'type': '',
+                'name': '',
+                'area': ''
+            }
+        },
+        'train': {
+            'book': {
+                'booked': [],
+                'people': ''
+            },
+            'semi': {
+                'leaveAt': '',
+                'destination': '',
+                'day': '',
+                'arriveBy': '',
+                'departure': ''
+            }
+        }
+    }, {}]
+
+    utter3 = {
+        'User-1':
+        'am looking for a place to to stay that has cheap price range it should be in a type of hotel',
+        'System-1': 'Okay, do you have a specific area you want to stay in?',
+        'Dialog_Act-1': {
+            'Hotel-Request': [['Area', '?']]
+        },
+        'User-2':
+        'no, i just need to make sure it\'s cheap. 
oh, and i need parking', + 'System-2': + 'I found 1 cheap hotel for you that includes parking. Do you like me to book it?', + 'Dialog_Act-2': { + 'Booking-Inform': [['none', 'none']], + 'Hotel-Inform': [['Price', 'cheap'], ['Choice', '1'], + ['Parking', 'none']] + }, + 'User-3': 'Yes, please. 6 people 3 nights starting on tuesday.' + } + + history_states3 = [{}, { + 'taxi': { + 'book': { + 'booked': [] + }, + 'semi': { + 'leaveAt': '', + 'destination': '', + 'departure': '', + 'arriveBy': '' + } + }, + 'police': { + 'book': { + 'booked': [] + }, + 'semi': {} + }, + 'restaurant': { + 'book': { + 'booked': [], + 'people': '', + 'day': '', + 'time': '' + }, + 'semi': { + 'food': '', + 'pricerange': '', + 'name': '', + 'area': '' + } + }, + 'hospital': { + 'book': { + 'booked': [] + }, + 'semi': { + 'department': '' + } + }, + 'hotel': { + 'book': { + 'booked': [], + 'people': '', + 'day': '', + 'stay': '' + }, + 'semi': { + 'name': 'not mentioned', + 'area': 'not mentioned', + 'parking': 'not mentioned', + 'pricerange': 'cheap', + 'stars': 'not mentioned', + 'internet': 'not mentioned', + 'type': 'hotel' + } + }, + 'attraction': { + 'book': { + 'booked': [] + }, + 'semi': { + 'type': '', + 'name': '', + 'area': '' + } + }, + 'train': { + 'book': { + 'booked': [], + 'people': '' + }, + 'semi': { + 'leaveAt': '', + 'destination': '', + 'day': '', + 'arriveBy': '', + 'departure': '' + } + } + }, {}, { + 'taxi': { + 'book': { + 'booked': [] + }, + 'semi': { + 'leaveAt': '', + 'destination': '', + 'departure': '', + 'arriveBy': '' + } + }, + 'police': { + 'book': { + 'booked': [] + }, + 'semi': {} + }, + 'restaurant': { + 'book': { + 'booked': [], + 'people': '', + 'day': '', + 'time': '' + }, + 'semi': { + 'food': '', + 'pricerange': '', + 'name': '', + 'area': '' + } + }, + 'hospital': { + 'book': { + 'booked': [] + }, + 'semi': { + 'department': '' + } + }, + 'hotel': { + 'book': { + 'booked': [], + 'people': '', + 'day': '', + 'stay': '' + }, + 'semi': { + 'name': 'not mentioned', + 'area': 'not mentioned', + 'parking': 'yes', + 'pricerange': 'cheap', + 'stars': 'not mentioned', + 'internet': 'not mentioned', + 'type': 'hotel' + } + }, + 'attraction': { + 'book': { + 'booked': [] + }, + 'semi': { + 'type': '', + 'name': '', + 'area': '' + } + }, + 'train': { + 'book': { + 'booked': [], + 'people': '' + }, + 'semi': { + 'leaveAt': '', + 'destination': '', + 'day': '', + 'arriveBy': '', + 'departure': '' + } + } + }, {}] + + example = processor.create_example(utter2, history_states2, set_type, + slot_list, {}, append_history, + use_history_labels, swap_utterances, + label_value_repetitions, + delexicalize_sys_utts, unk_token, + analyze) + print(f'utterances is {example}') diff --git a/modelscope/preprocessors/space/tensorlistdataset.py b/modelscope/preprocessors/space/tensorlistdataset.py new file mode 100644 index 00000000..45243261 --- /dev/null +++ b/modelscope/preprocessors/space/tensorlistdataset.py @@ -0,0 +1,59 @@ +# +# Copyright 2020 Heinrich Heine University Duesseldorf +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.

+from torch.utils.data import Dataset
+
+
+class TensorListDataset(Dataset):
+    r"""Dataset wrapping tensors, tensor dicts and tensor lists.
+
+    Arguments:
+        *data (Tensor or dict or list of Tensors): tensors that have the same
+        size in the first dimension.
+    """
+
+    def __init__(self, *data):
+        if isinstance(data[0], dict):
+            size = list(data[0].values())[0].size(0)
+        elif isinstance(data[0], list):
+            size = data[0][0].size(0)
+        else:
+            size = data[0].size(0)
+        for element in data:
+            if isinstance(element, dict):
+                assert all(
+                    size == tensor.size(0)
+                    for tensor in element.values())  # dict of tensors
+            elif isinstance(element, list):
+                assert all(size == tensor.size(0)
+                           for tensor in element)  # list of tensors
+            else:
+                assert size == element.size(0)  # tensor
+        self.size = size
+        self.data = data
+
+    def __getitem__(self, index):
+        result = []
+        for element in self.data:
+            if isinstance(element, dict):
+                result.append({k: v[index] for k, v in element.items()})
+            elif isinstance(element, list):
+                # Materialize a list; appending a bare generator expression
+                # here would break downstream collation.
+                result.append([v[index] for v in element])
+            else:
+                result.append(element[index])
+        return tuple(result)
+
+    def __len__(self):
+        return self.size
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index d4a19304..640e55e2 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -48,6 +48,7 @@ class Tasks(object):
     text_generation = 'text-generation'
     dialog_modeling = 'dialog-modeling'
     dialog_intent_prediction = 'dialog-intent-prediction'
+    dialog_state_tracking = 'dialog-state-tracking'
     table_question_answering = 'table-question-answering'
     feature_extraction = 'feature-extraction'
     fill_mask = 'fill-mask'
diff --git a/modelscope/utils/nlp/space/utils_dst.py b/modelscope/utils/nlp/space/utils_dst.py
new file mode 100644
index 00000000..2a7e67d7
--- /dev/null
+++ b/modelscope/utils/nlp/space/utils_dst.py
@@ -0,0 +1,10 @@
+def batch_to_device(batch, device):
+    batch_on_device = []
+    for element in batch:
+        if isinstance(element, dict):
+            batch_on_device.append(
+                {k: v.to(device)
+                 for k, v in element.items()})
+        else:
+            batch_on_device.append(element.to(device))
+    return tuple(batch_on_device)
diff --git a/requirements/nlp.txt b/requirements/nlp.txt
index 5eb76494..5407c713 100644
--- a/requirements/nlp.txt
+++ b/requirements/nlp.txt
@@ -1,3 +1,3 @@
-https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz
+http://ait-public.oss-cn-hangzhou-zmf.aliyuncs.com/jizhu/en_core_web_sm-2.3.1.tar.gz
 sofa==1.0.5
 spacy>=2.3.5
diff --git a/tests/pipelines/test_dialog_intent_prediction.py b/tests/pipelines/test_dialog_intent_prediction.py
index 051f979b..f26211d3 100644
--- a/tests/pipelines/test_dialog_intent_prediction.py
+++ b/tests/pipelines/test_dialog_intent_prediction.py
@@ -18,7 +18,7 @@ class DialogIntentPredictionTest(unittest.TestCase):
     ]
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
-    def test_run(self):
+    def test_run_by_direct_model_download(self):
         cache_path = snapshot_download(self.model_id)
         preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path)
         model = SpaceForDialogIntent(
@@ -56,6 +56,20 @@ class DialogIntentPredictionTest(unittest.TestCase):
         for my_pipeline, item in list(zip(pipelines, self.test_case)):
             print(my_pipeline(item))
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name(self):
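+        # The pipeline is built from the model id alone; configuration and
+        # weights are fetched from the model hub on first use.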
+ pipelines = [ + pipeline(task=Tasks.dialog_intent_prediction, model=self.model_id) + ] + for my_pipeline, item in list(zip(pipelines, self.test_case)): + print(my_pipeline(item)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipelines = [pipeline(task=Tasks.dialog_intent_prediction)] + for my_pipeline, item in list(zip(pipelines, self.test_case)): + print(my_pipeline(item)) + if __name__ == '__main__': unittest.main() diff --git a/tests/pipelines/test_dialog_modeling.py b/tests/pipelines/test_dialog_modeling.py index 7279bbff..83157317 100644 --- a/tests/pipelines/test_dialog_modeling.py +++ b/tests/pipelines/test_dialog_modeling.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import unittest +from typing import List from modelscope.hub.snapshot_download import snapshot_download from modelscope.models import Model @@ -89,8 +90,22 @@ class DialogModelingTest(unittest.TestCase): } } + def generate_and_print_dialog_response( + self, pipelines: List[DialogModelingPipeline]): + + result = {} + for step, item in enumerate(self.test_case['sng0073']['log']): + user = item['user'] + print('user: {}'.format(user)) + + result = pipelines[step % 2]({ + 'user_input': user, + 'history': result + }) + print('response : {}'.format(result['response'])) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') - def test_run(self): + def test_run_by_direct_model_download(self): cache_path = snapshot_download(self.model_id) @@ -106,17 +121,7 @@ class DialogModelingTest(unittest.TestCase): model=model, preprocessor=preprocessor) ] - - result = {} - for step, item in enumerate(self.test_case['sng0073']['log']): - user = item['user'] - print('user: {}'.format(user)) - - result = pipelines[step % 2]({ - 'user_input': user, - 'history': result - }) - print('response : {}'.format(result['response'])) + self.generate_and_print_dialog_response(pipelines) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_from_modelhub(self): @@ -131,16 +136,23 @@ class DialogModelingTest(unittest.TestCase): preprocessor=preprocessor) ] - result = {} - for step, item in enumerate(self.test_case['sng0073']['log']): - user = item['user'] - print('user: {}'.format(user)) + self.generate_and_print_dialog_response(pipelines) - result = pipelines[step % 2]({ - 'user_input': user, - 'history': result - }) - print('response : {}'.format(result['response'])) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipelines = [ + pipeline(task=Tasks.dialog_modeling, model=self.model_id), + pipeline(task=Tasks.dialog_modeling, model=self.model_id) + ] + self.generate_and_print_dialog_response(pipelines) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipelines = [ + pipeline(task=Tasks.dialog_modeling), + pipeline(task=Tasks.dialog_modeling) + ] + self.generate_and_print_dialog_response(pipelines) if __name__ == '__main__': diff --git a/tests/pipelines/test_dialog_state_tracking.py b/tests/pipelines/test_dialog_state_tracking.py new file mode 100644 index 00000000..2110adba --- /dev/null +++ b/tests/pipelines/test_dialog_state_tracking.py @@ -0,0 +1,143 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
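+# The tests below replay a MultiWOZ-style multi-domain dialogue turn by turn,
+# feeding the accumulated `utter` dict and the running `history_states` into
+# the dialog-state-tracking pipeline.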
+import unittest +from typing import List + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model, SpaceForDialogStateTracking +from modelscope.pipelines import DialogStateTrackingPipeline, pipeline +from modelscope.preprocessors import DialogStateTrackingPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class DialogStateTrackingTest(unittest.TestCase): + model_id = 'damo/nlp_space_dialog-state-tracking' + test_case = [{ + 'User-1': + 'Hi, I\'m looking for a train that is going to cambridge and arriving there by 20:45, ' + 'is there anything like that?' + }, { + 'System-1': + 'There are over 1,000 trains like that. Where will you be departing from?', + 'Dialog_Act-1': { + 'Train-Inform': [['Choice', 'over 1'], ['Choice', '000']], + 'Train-Request': [['Depart', '?']] + }, + 'User-2': 'I am departing from birmingham new street.' + }, { + 'System-2': 'Can you confirm your desired travel day?', + 'Dialog_Act-2': { + 'Train-Request': [['Day', '?']] + }, + 'User-3': 'I would like to leave on wednesday' + }, { + 'System-3': + 'I show a train leaving birmingham new street at 17:40 and arriving at 20:23 on Wednesday. ' + 'Will this work for you?', + 'Dialog_Act-3': { + 'Train-Inform': [['Arrive', '20:23'], ['Leave', '17:40'], + ['Day', 'Wednesday'], + ['Depart', 'birmingham new street']] + }, + 'User-4': + 'That will, yes. Please make a booking for 5 people please.', + }, { + 'System-4': + 'I\'ve booked your train tickets, and your reference number is A9NHSO9Y.', + 'Dialog_Act-4': { + 'Train-OfferBooked': [['Ref', 'A9NHSO9Y']] + }, + 'User-5': + 'Thanks so much. I would also need a place to say. ' + 'I am looking for something with 4 stars and has free wifi.' + }, { + 'System-5': + 'How about the cambridge belfry? ' + 'It has all the attributes you requested and a great name! ' + 'Maybe even a real belfry?', + 'Dialog_Act-5': { + 'Hotel-Recommend': [['Name', 'the cambridge belfry']] + }, + 'User-6': + 'That sounds great, could you make a booking for me please?', + }, { + 'System-6': + 'What day would you like your booking for?', + 'Dialog_Act-6': { + 'Booking-Request': [['Day', '?']] + }, + 'User-7': + 'Please book it for Wednesday for 5 people and 5 nights, please.', + }, { + 'System-7': 'Booking was successful. 
Reference number is : 5NAWGJDC.', + 'Dialog_Act-7': { + 'Booking-Book': [['Ref', '5NAWGJDC']] + }, + 'User-8': 'Thank you, goodbye', + }] + + def tracking_and_print_dialog_states( + self, pipelines: List[DialogStateTrackingPipeline]): + import json + pipelines_len = len(pipelines) + history_states = [{}] + utter = {} + for step, item in enumerate(self.test_case): + utter.update(item) + result = pipelines[step % pipelines_len]({ + 'utter': + utter, + 'history_states': + history_states + }) + print(json.dumps(result)) + + history_states.extend([result['dialog_states'], {}]) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + + model = SpaceForDialogStateTracking(cache_path) + preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) + pipelines = [ + DialogStateTrackingPipeline( + model=model, preprocessor=preprocessor), + pipeline( + task=Tasks.dialog_state_tracking, + model=model, + preprocessor=preprocessor) + ] + self.tracking_and_print_dialog_states(pipelines) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + preprocessor = DialogStateTrackingPreprocessor( + model_dir=model.model_dir) + pipelines = [ + DialogStateTrackingPipeline( + model=model, preprocessor=preprocessor), + pipeline( + task=Tasks.dialog_state_tracking, + model=model, + preprocessor=preprocessor) + ] + + self.tracking_and_print_dialog_states(pipelines) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipelines = [ + pipeline(task=Tasks.dialog_state_tracking, model=self.model_id) + ] + self.tracking_and_print_dialog_states(pipelines) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipelines = [pipeline(task=Tasks.dialog_state_tracking)] + self.tracking_and_print_dialog_states(pipelines) + + +if __name__ == '__main__': + unittest.main()
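For reviewers, here is a minimal sketch (not part of the diff) of how the two new helpers, TensorListDataset and batch_to_device, are expected to compose with a standard torch DataLoader. The tensor names, shapes, and slot keys are illustrative assumptions rather than values taken from the model:

import torch
from torch.utils.data import DataLoader

from modelscope.preprocessors.space.tensorlistdataset import TensorListDataset
from modelscope.utils.nlp.space.utils_dst import batch_to_device

# Illustrative feature tensors: 8 examples with sequence length 16.
n, seq_len = 8, 16
input_ids = torch.zeros(n, seq_len, dtype=torch.long)
input_mask = torch.ones(n, seq_len, dtype=torch.long)
# Per-slot tensors travel as a single dict, mirroring the per-slot dicts
# built by convert_examples_to_features (slot names here are examples only).
start_pos = {
    'hotel-area': torch.zeros(n, dtype=torch.long),
    'hotel-type': torch.zeros(n, dtype=torch.long),
}

dataset = TensorListDataset(input_ids, input_mask, start_pos)
loader = DataLoader(dataset, batch_size=4)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
for batch in loader:
    # batch_to_device moves dict entries value by value and plain tensors
    # directly, returning the batch as a tuple.
    ids_b, mask_b, start_pos_b = batch_to_device(batch, device)

The default collate function stacks the dict entries key-wise, so the per-slot layout produced by the preprocessor survives batching unchanged; batch_to_device then only needs to distinguish dicts from tensors, which is exactly what the new utils_dst.py implements.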