code review:https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9227018master
| @@ -62,6 +62,7 @@ class Pipelines(object): | |||
| nli = 'nli' | |||
| dialog_intent_prediction = 'dialog-intent-prediction' | |||
| dialog_modeling = 'dialog-modeling' | |||
| dialog_state_tracking = 'dialog-state-tracking' | |||
| zero_shot_classification = 'zero-shot-classification' | |||
| # audio tasks | |||
| @@ -112,6 +113,7 @@ class Preprocessors(object): | |||
| sen_cls_tokenizer = 'sen-cls-tokenizer' | |||
| dialog_intent_preprocessor = 'dialog-intent-preprocessor' | |||
| dialog_modeling_preprocessor = 'dialog-modeling-preprocessor' | |||
| dialog_state_tracking_preprocessor = 'dialog-state-tracking-preprocessor' | |||
| sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer' | |||
| zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' | |||
| @@ -15,12 +15,14 @@ except ModuleNotFoundError as e: | |||
| try: | |||
| from .audio.kws import GenericKeyWordSpotting | |||
| from .multi_modal import OfaForImageCaptioning | |||
| from .nlp import ( | |||
| BertForMaskedLM, BertForSequenceClassification, CsanmtForTranslation, | |||
| SbertForNLI, SbertForSentenceSimilarity, | |||
| SbertForSentimentClassification, SbertForTokenClassification, | |||
| SbertForZeroShotClassification, SpaceForDialogIntent, | |||
| SpaceForDialogModeling, StructBertForMaskedLM, VecoForMaskedLM) | |||
| from .nlp import (BertForMaskedLM, BertForSequenceClassification, | |||
| CsanmtForTranslation, SbertForNLI, | |||
| SbertForSentenceSimilarity, | |||
| SbertForSentimentClassification, | |||
| SbertForTokenClassification, | |||
| SbertForZeroShotClassification, SpaceForDialogIntent, | |||
| SpaceForDialogModeling, SpaceForDialogStateTracking, | |||
| StructBertForMaskedLM, VecoForMaskedLM) | |||
| from .audio.ans.frcrn import FRCRNModel | |||
| except ModuleNotFoundError as e: | |||
| if str(e) == "No module named 'pytorch'": | |||
| @@ -9,3 +9,4 @@ from .sbert_for_token_classification import * # noqa F403 | |||
| from .sbert_for_zero_shot_classification import * # noqa F403 | |||
| from .space.dialog_intent_prediction_model import * # noqa F403 | |||
| from .space.dialog_modeling_model import * # noqa F403 | |||
| from .space.dialog_state_tracking_model import * # noqa F403 | |||
| @@ -63,15 +63,16 @@ class SpaceForDialogIntent(Model): | |||
| """return the result by the model | |||
| Args: | |||
| input (Dict[str, Any]): the preprocessed data | |||
| input (Dict[str, Tensor]): the preprocessed data | |||
| Returns: | |||
| Dict[str, np.ndarray]: results | |||
| Dict[str, Tensor]: results | |||
| Example: | |||
| { | |||
| 'predictions': array([1]), # lable 0-negative 1-positive | |||
| 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), | |||
| 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||
| 'pred': array([2.62349960e-03 4.12110658e-03 4.12748595e-05 3.77560973e-05 | |||
| 1.08599677e-04 1.72710388e-05 2.95618793e-05 1.93638436e-04 | |||
| 6.45841064e-05 1.15997791e-04 5.11605394e-05 9.87020373e-01 | |||
| 2.66957268e-05 4.72324500e-05 9.74208378e-05], dtype=float32) | |||
| } | |||
| """ | |||
| import numpy as np | |||
| @@ -62,15 +62,17 @@ class SpaceForDialogModeling(Model): | |||
| """return the result by the model | |||
| Args: | |||
| input (Dict[str, Any]): the preprocessed data | |||
| input (Dict[str, Tensor]): the preprocessed data | |||
| Returns: | |||
| Dict[str, np.ndarray]: results | |||
| Dict[str, Tensor]: results | |||
| Example: | |||
| { | |||
| 'predictions': array([1]), # lable 0-negative 1-positive | |||
| 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), | |||
| 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||
| 'labels': array([1,192,321,12]), # lable | |||
| 'resp': array([293,1023,123,1123]), #vocab label for response | |||
| 'bspn': array([123,321,2,24,1 ]), | |||
| 'aspn': array([47,8345,32,29,1983]), | |||
| 'db': array([19, 24, 20]), | |||
| } | |||
| """ | |||
| @@ -0,0 +1,103 @@ | |||
| import os | |||
| from typing import Any, Dict | |||
| from modelscope.utils.constant import Tasks | |||
| from ....metainfo import Models | |||
| from ....utils.nlp.space.utils_dst import batch_to_device | |||
| from ...base import Model, Tensor | |||
| from ...builder import MODELS | |||
| __all__ = ['SpaceForDialogStateTracking'] | |||
| @MODELS.register_module(Tasks.dialog_state_tracking, module_name=Models.space) | |||
| class SpaceForDialogStateTracking(Model): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """initialize the test generation model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| from sofa.models.space import SpaceForDST, SpaceConfig | |||
| self.model_dir = model_dir | |||
| self.config = SpaceConfig.from_pretrained(self.model_dir) | |||
| self.model = SpaceForDST.from_pretrained(self.model_dir) | |||
| self.model.to(self.config.device) | |||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| """return the result by the model | |||
| Args: | |||
| input (Dict[str, Tensor]): the preprocessed data | |||
| Returns: | |||
| Dict[str, Tensor]: results | |||
| Example: | |||
| { | |||
| 'inputs': dict(input_ids, input_masks,start_pos), # tracking states | |||
| 'outputs': dict(slots_logits), | |||
| 'unique_ids': str(test-example.json-0), # default value | |||
| 'input_ids_unmasked': array([101, 7632, 1010,0,0,0]) | |||
| 'values': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]), | |||
| 'inform': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]), | |||
| 'prefix': str('final'), #default value | |||
| 'ds': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]) | |||
| } | |||
| """ | |||
| import numpy as np | |||
| import torch | |||
| self.model.eval() | |||
| batch = input['batch'] | |||
| batch = batch_to_device(batch, self.config.device) | |||
| features = input['features'] | |||
| diag_state = input['diag_state'] | |||
| turn_itrs = [features[i.item()].guid.split('-')[2] for i in batch[9]] | |||
| reset_diag_state = np.where(np.array(turn_itrs) == '0')[0] | |||
| for slot in self.config.dst_slot_list: | |||
| for i in reset_diag_state: | |||
| diag_state[slot][i] = 0 | |||
| with torch.no_grad(): | |||
| inputs = { | |||
| 'input_ids': batch[0], | |||
| 'input_mask': batch[1], | |||
| 'segment_ids': batch[2], | |||
| 'start_pos': batch[3], | |||
| 'end_pos': batch[4], | |||
| 'inform_slot_id': batch[5], | |||
| 'refer_id': batch[6], | |||
| 'diag_state': diag_state, | |||
| 'class_label_id': batch[8] | |||
| } | |||
| unique_ids = [features[i.item()].guid for i in batch[9]] | |||
| values = [features[i.item()].values for i in batch[9]] | |||
| input_ids_unmasked = [ | |||
| features[i.item()].input_ids_unmasked for i in batch[9] | |||
| ] | |||
| inform = [features[i.item()].inform for i in batch[9]] | |||
| outputs = self.model(**inputs) | |||
| # Update dialog state for next turn. | |||
| for slot in self.config.dst_slot_list: | |||
| updates = outputs[2][slot].max(1)[1] | |||
| for i, u in enumerate(updates): | |||
| if u != 0: | |||
| diag_state[slot][i] = u | |||
| return { | |||
| 'inputs': inputs, | |||
| 'outputs': outputs, | |||
| 'unique_ids': unique_ids, | |||
| 'input_ids_unmasked': input_ids_unmasked, | |||
| 'values': values, | |||
| 'inform': inform, | |||
| 'prefix': 'final', | |||
| 'ds': input['ds'] | |||
| } | |||
| @@ -41,6 +41,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| 'damo/nlp_space_dialog-intent-prediction'), | |||
| Tasks.dialog_modeling: (Pipelines.dialog_modeling, | |||
| 'damo/nlp_space_dialog-modeling'), | |||
| Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking, | |||
| 'damo/nlp_space_dialog-state-tracking'), | |||
| Tasks.image_captioning: (Pipelines.image_caption, | |||
| 'damo/ofa_image-caption_coco_large_en'), | |||
| Tasks.image_generation: | |||
| @@ -1,6 +1,7 @@ | |||
| try: | |||
| from .dialog_intent_prediction_pipeline import * # noqa F403 | |||
| from .dialog_modeling_pipeline import * # noqa F403 | |||
| from .dialog_state_tracking_pipeline import * # noqa F403 | |||
| from .fill_mask_pipeline import * # noqa F403 | |||
| from .nli_pipeline import * # noqa F403 | |||
| from .sentence_similarity_pipeline import * # noqa F403 | |||
| @@ -1,8 +1,9 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict | |||
| from typing import Any, Dict, Union | |||
| from ...metainfo import Pipelines | |||
| from ...models import Model | |||
| from ...models.nlp import SpaceForDialogIntent | |||
| from ...preprocessors import DialogIntentPredictionPreprocessor | |||
| from ...utils.constant import Tasks | |||
| @@ -18,17 +19,22 @@ __all__ = ['DialogIntentPredictionPipeline'] | |||
| module_name=Pipelines.dialog_intent_prediction) | |||
| class DialogIntentPredictionPipeline(Pipeline): | |||
| def __init__(self, model: SpaceForDialogIntent, | |||
| preprocessor: DialogIntentPredictionPreprocessor, **kwargs): | |||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||
| def __init__(self, | |||
| model: Union[SpaceForDialogIntent, str], | |||
| preprocessor: DialogIntentPredictionPreprocessor = None, | |||
| **kwargs): | |||
| """use `model` and `preprocessor` to create a dialog intent prediction pipeline | |||
| Args: | |||
| model (SequenceClassificationModel): a model instance | |||
| preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | |||
| model (SpaceForDialogIntent): a model instance | |||
| preprocessor (DialogIntentPredictionPreprocessor): a preprocessor instance | |||
| """ | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| model = model if isinstance( | |||
| model, SpaceForDialogIntent) else Model.from_pretrained(model) | |||
| if preprocessor is None: | |||
| preprocessor = DialogIntentPredictionPreprocessor(model.model_dir) | |||
| self.model = model | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| self.categories = preprocessor.categories | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | |||
| @@ -1,8 +1,9 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict, Optional | |||
| from typing import Any, Dict, Union | |||
| from ...metainfo import Pipelines | |||
| from ...models import Model | |||
| from ...models.nlp import SpaceForDialogModeling | |||
| from ...preprocessors import DialogModelingPreprocessor | |||
| from ...utils.constant import Tasks | |||
| @@ -17,17 +18,22 @@ __all__ = ['DialogModelingPipeline'] | |||
| Tasks.dialog_modeling, module_name=Pipelines.dialog_modeling) | |||
| class DialogModelingPipeline(Pipeline): | |||
| def __init__(self, model: SpaceForDialogModeling, | |||
| preprocessor: DialogModelingPreprocessor, **kwargs): | |||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||
| def __init__(self, | |||
| model: Union[SpaceForDialogModeling, str], | |||
| preprocessor: DialogModelingPreprocessor = None, | |||
| **kwargs): | |||
| """use `model` and `preprocessor` to create a dialog modleing pipeline for dialog response generation | |||
| Args: | |||
| model (SequenceClassificationModel): a model instance | |||
| preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | |||
| model (SpaceForDialogModeling): a model instance | |||
| preprocessor (DialogModelingPreprocessor): a preprocessor instance | |||
| """ | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| model = model if isinstance( | |||
| model, SpaceForDialogModeling) else Model.from_pretrained(model) | |||
| self.model = model | |||
| if preprocessor is None: | |||
| preprocessor = DialogModelingPreprocessor(model.model_dir) | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| self.preprocessor = preprocessor | |||
| def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: | |||
| @@ -43,7 +49,6 @@ class DialogModelingPipeline(Pipeline): | |||
| inputs['resp']) | |||
| assert len(sys_rsp) > 2 | |||
| sys_rsp = sys_rsp[1:len(sys_rsp) - 1] | |||
| inputs[OutputKeys.RESPONSE] = sys_rsp | |||
| return inputs | |||
| @@ -0,0 +1,159 @@ | |||
| from typing import Any, Dict, Union | |||
| from ...metainfo import Pipelines | |||
| from ...models import Model, SpaceForDialogStateTracking | |||
| from ...preprocessors import DialogStateTrackingPreprocessor | |||
| from ...utils.constant import Tasks | |||
| from ..base import Pipeline | |||
| from ..builder import PIPELINES | |||
| from ..outputs import OutputKeys | |||
| __all__ = ['DialogStateTrackingPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.dialog_state_tracking, module_name=Pipelines.dialog_state_tracking) | |||
| class DialogStateTrackingPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[SpaceForDialogStateTracking, str], | |||
| preprocessor: DialogStateTrackingPreprocessor = None, | |||
| **kwargs): | |||
| """use `model` and `preprocessor` to create a dialog state tracking pipeline for | |||
| observation of dialog states tracking after many turns of open domain dialogue | |||
| Args: | |||
| model (SpaceForDialogStateTracking): a model instance | |||
| preprocessor (DialogStateTrackingPreprocessor): a preprocessor instance | |||
| """ | |||
| model = model if isinstance( | |||
| model, | |||
| SpaceForDialogStateTracking) else Model.from_pretrained(model) | |||
| self.model = model | |||
| if preprocessor is None: | |||
| preprocessor = DialogStateTrackingPreprocessor(model.model_dir) | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| self.tokenizer = preprocessor.tokenizer | |||
| self.config = preprocessor.config | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Any]): _description_ | |||
| Returns: | |||
| Dict[str, str]: the prediction results | |||
| """ | |||
| _inputs = inputs['inputs'] | |||
| _outputs = inputs['outputs'] | |||
| unique_ids = inputs['unique_ids'] | |||
| input_ids_unmasked = inputs['input_ids_unmasked'] | |||
| values = inputs['values'] | |||
| inform = inputs['inform'] | |||
| prefix = inputs['prefix'] | |||
| ds = inputs['ds'] | |||
| ds = predict_and_format(self.config, self.tokenizer, _inputs, | |||
| _outputs[2], _outputs[3], _outputs[4], | |||
| _outputs[5], unique_ids, input_ids_unmasked, | |||
| values, inform, prefix, ds) | |||
| return {OutputKeys.DIALOG_STATES: ds} | |||
| def predict_and_format(config, tokenizer, features, per_slot_class_logits, | |||
| per_slot_start_logits, per_slot_end_logits, | |||
| per_slot_refer_logits, ids, input_ids_unmasked, values, | |||
| inform, prefix, ds): | |||
| import re | |||
| prediction_list = [] | |||
| dialog_state = ds | |||
| for i in range(len(ids)): | |||
| if int(ids[i].split('-')[2]) == 0: | |||
| dialog_state = {slot: 'none' for slot in config.dst_slot_list} | |||
| prediction = {} | |||
| prediction_addendum = {} | |||
| for slot in config.dst_slot_list: | |||
| class_logits = per_slot_class_logits[slot][i] | |||
| start_logits = per_slot_start_logits[slot][i] | |||
| end_logits = per_slot_end_logits[slot][i] | |||
| refer_logits = per_slot_refer_logits[slot][i] | |||
| input_ids = features['input_ids'][i].tolist() | |||
| class_label_id = int(features['class_label_id'][slot][i]) | |||
| start_pos = int(features['start_pos'][slot][i]) | |||
| end_pos = int(features['end_pos'][slot][i]) | |||
| refer_id = int(features['refer_id'][slot][i]) | |||
| class_prediction = int(class_logits.argmax()) | |||
| start_prediction = int(start_logits.argmax()) | |||
| end_prediction = int(end_logits.argmax()) | |||
| refer_prediction = int(refer_logits.argmax()) | |||
| prediction['guid'] = ids[i].split('-') | |||
| prediction['class_prediction_%s' % slot] = class_prediction | |||
| prediction['class_label_id_%s' % slot] = class_label_id | |||
| prediction['start_prediction_%s' % slot] = start_prediction | |||
| prediction['start_pos_%s' % slot] = start_pos | |||
| prediction['end_prediction_%s' % slot] = end_prediction | |||
| prediction['end_pos_%s' % slot] = end_pos | |||
| prediction['refer_prediction_%s' % slot] = refer_prediction | |||
| prediction['refer_id_%s' % slot] = refer_id | |||
| prediction['input_ids_%s' % slot] = input_ids | |||
| if class_prediction == config.dst_class_types.index('dontcare'): | |||
| dialog_state[slot] = 'dontcare' | |||
| elif class_prediction == config.dst_class_types.index( | |||
| 'copy_value'): | |||
| input_tokens = tokenizer.convert_ids_to_tokens( | |||
| input_ids_unmasked[i]) | |||
| dialog_state[slot] = ' '.join( | |||
| input_tokens[start_prediction:end_prediction + 1]) | |||
| dialog_state[slot] = re.sub('(^| )##', '', dialog_state[slot]) | |||
| elif 'true' in config.dst_class_types and class_prediction == config.dst_class_types.index( | |||
| 'true'): | |||
| dialog_state[slot] = 'true' | |||
| elif 'false' in config.dst_class_types and class_prediction == config.dst_class_types.index( | |||
| 'false'): | |||
| dialog_state[slot] = 'false' | |||
| elif class_prediction == config.dst_class_types.index('inform'): | |||
| # dialog_state[slot] = '§§' + inform[i][slot] | |||
| if isinstance(inform[i][slot], str): | |||
| dialog_state[slot] = inform[i][slot] | |||
| elif isinstance(inform[i][slot], list): | |||
| dialog_state[slot] = inform[i][slot][0] | |||
| # Referral case is handled below | |||
| prediction_addendum['slot_prediction_%s' | |||
| % slot] = dialog_state[slot] | |||
| prediction_addendum['slot_groundtruth_%s' % slot] = values[i][slot] | |||
| # Referral case. All other slot values need to be seen first in order | |||
| # to be able to do this correctly. | |||
| for slot in config.dst_slot_list: | |||
| class_logits = per_slot_class_logits[slot][i] | |||
| refer_logits = per_slot_refer_logits[slot][i] | |||
| class_prediction = int(class_logits.argmax()) | |||
| refer_prediction = int(refer_logits.argmax()) | |||
| if 'refer' in config.dst_class_types and class_prediction == config.dst_class_types.index( | |||
| 'refer'): | |||
| # Only slots that have been mentioned before can be referred to. | |||
| # One can think of a situation where one slot is referred to in the same utterance. | |||
| # This phenomenon is however currently not properly covered in the training data | |||
| # label generation process. | |||
| dialog_state[slot] = dialog_state[config.dst_slot_list[ | |||
| refer_prediction - 1]] | |||
| prediction_addendum['slot_prediction_%s' % | |||
| slot] = dialog_state[slot] # Value update | |||
| prediction.update(prediction_addendum) | |||
| prediction_list.append(prediction) | |||
| return dialog_state | |||
| @@ -74,5 +74,4 @@ class SentimentClassificationPipeline(Pipeline): | |||
| probs = probs[cls_ids].tolist() | |||
| cls_names = [self.model.id2label[cid] for cid in cls_ids] | |||
| return {OutputKeys.SCORES: probs, OutputKeys.LABELS: cls_names} | |||
| @@ -29,7 +29,6 @@ class ZeroShotClassificationPipeline(Pipeline): | |||
| preprocessor: ZeroShotClassificationPreprocessor = None, | |||
| **kwargs): | |||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||
| Args: | |||
| model (SbertForSentimentClassification): a model instance | |||
| preprocessor (SentimentClassificationPreprocessor): a preprocessor instance | |||
| @@ -39,10 +38,8 @@ class ZeroShotClassificationPipeline(Pipeline): | |||
| model = model if isinstance( | |||
| model, | |||
| SbertForZeroShotClassification) else Model.from_pretrained(model) | |||
| self.entailment_id = 0 | |||
| self.contradiction_id = 2 | |||
| if preprocessor is None: | |||
| preprocessor = ZeroShotClassificationPreprocessor(model.model_dir) | |||
| model.eval() | |||
| @@ -51,7 +48,6 @@ class ZeroShotClassificationPipeline(Pipeline): | |||
| def _sanitize_parameters(self, **kwargs): | |||
| preprocess_params = {} | |||
| postprocess_params = {} | |||
| if 'candidate_labels' in kwargs: | |||
| candidate_labels = kwargs.pop('candidate_labels') | |||
| preprocess_params['candidate_labels'] = candidate_labels | |||
| @@ -60,7 +56,6 @@ class ZeroShotClassificationPipeline(Pipeline): | |||
| raise ValueError('You must include at least one label.') | |||
| preprocess_params['hypothesis_template'] = kwargs.pop( | |||
| 'hypothesis_template', '{}') | |||
| postprocess_params['multi_label'] = kwargs.pop('multi_label', False) | |||
| return preprocess_params, {}, postprocess_params | |||
| @@ -74,14 +69,11 @@ class ZeroShotClassificationPipeline(Pipeline): | |||
| candidate_labels, | |||
| multi_label=False) -> Dict[str, Any]: | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Any]): _description_ | |||
| Returns: | |||
| Dict[str, Any]: the prediction results | |||
| """ | |||
| logits = inputs['logits'] | |||
| if multi_label or len(candidate_labels) == 1: | |||
| logits = logits[..., [self.contradiction_id, self.entailment_id]] | |||
| @@ -89,7 +81,6 @@ class ZeroShotClassificationPipeline(Pipeline): | |||
| else: | |||
| logits = logits[..., self.entailment_id] | |||
| scores = softmax(logits, axis=-1) | |||
| reversed_index = list(reversed(scores.argsort())) | |||
| result = { | |||
| OutputKeys.LABELS: [candidate_labels[i] for i in reversed_index], | |||
| @@ -21,6 +21,7 @@ class OutputKeys(object): | |||
| TRANSLATION = 'translation' | |||
| RESPONSE = 'response' | |||
| PREDICTION = 'prediction' | |||
| DIALOG_STATES = 'dialog_states' | |||
| VIDEO_EMBEDDING = 'video_embedding' | |||
| @@ -158,6 +159,7 @@ TASK_OUTPUTS = { | |||
| # } | |||
| Tasks.nli: [OutputKeys.SCORES, OutputKeys.LABELS], | |||
| # dialog intent prediction result for single sample | |||
| # {'pred': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05, | |||
| # 1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04, | |||
| # 6.45841064e-05, 1.15997791e-04, 5.11605394e-05, 9.87020373e-01, | |||
| @@ -181,9 +183,47 @@ TASK_OUTPUTS = { | |||
| Tasks.dialog_intent_prediction: | |||
| [OutputKeys.PREDICTION, OutputKeys.LABEL_POS, OutputKeys.LABEL], | |||
| # dialog modeling prediction result for single sample | |||
| # sys : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!'] | |||
| Tasks.dialog_modeling: [OutputKeys.RESPONSE], | |||
| # dialog state tracking result for single sample | |||
| # { | |||
| # "dialog_states": { | |||
| # "taxi-leaveAt": "none", | |||
| # "taxi-destination": "none", | |||
| # "taxi-departure": "none", | |||
| # "taxi-arriveBy": "none", | |||
| # "restaurant-book_people": "none", | |||
| # "restaurant-book_day": "none", | |||
| # "restaurant-book_time": "none", | |||
| # "restaurant-food": "none", | |||
| # "restaurant-pricerange": "none", | |||
| # "restaurant-name": "none", | |||
| # "restaurant-area": "none", | |||
| # "hotel-book_people": "none", | |||
| # "hotel-book_day": "none", | |||
| # "hotel-book_stay": "none", | |||
| # "hotel-name": "none", | |||
| # "hotel-area": "none", | |||
| # "hotel-parking": "none", | |||
| # "hotel-pricerange": "cheap", | |||
| # "hotel-stars": "none", | |||
| # "hotel-internet": "none", | |||
| # "hotel-type": "true", | |||
| # "attraction-type": "none", | |||
| # "attraction-name": "none", | |||
| # "attraction-area": "none", | |||
| # "train-book_people": "none", | |||
| # "train-leaveAt": "none", | |||
| # "train-destination": "none", | |||
| # "train-day": "none", | |||
| # "train-arriveBy": "none", | |||
| # "train-departure": "none" | |||
| # } | |||
| # } | |||
| Tasks.dialog_state_tracking: [OutputKeys.DIALOG_STATES], | |||
| # ============ audio tasks =================== | |||
| # audio processed for single file in PCM format | |||
| @@ -13,6 +13,7 @@ try: | |||
| from .nlp import * # noqa F403 | |||
| from .space.dialog_intent_prediction_preprocessor import * # noqa F403 | |||
| from .space.dialog_modeling_preprocessor import * # noqa F403 | |||
| from .space.dialog_state_tracking_preprocessor import * # noqa F403 | |||
| except ModuleNotFoundError as e: | |||
| if str(e) == "No module named 'tensorflow'": | |||
| pass | |||
| @@ -0,0 +1,133 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| from typing import Any, Dict | |||
| from modelscope.utils.constant import Fields | |||
| from modelscope.utils.type_assert import type_assert | |||
| from ...metainfo import Preprocessors | |||
| from ..base import Preprocessor | |||
| from ..builder import PREPROCESSORS | |||
| from .dst_processors import convert_examples_to_features, multiwoz22Processor | |||
| __all__ = ['DialogStateTrackingPreprocessor'] | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.dialog_state_tracking_preprocessor) | |||
| class DialogStateTrackingPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """preprocess the data via the vocab.txt from the `model_dir` path | |||
| Args: | |||
| model_dir (str): model path | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| from sofa.models.space import SpaceTokenizer, SpaceConfig | |||
| self.model_dir: str = model_dir | |||
| self.config = SpaceConfig.from_pretrained(self.model_dir) | |||
| self.tokenizer = SpaceTokenizer.from_pretrained(self.model_dir) | |||
| self.processor = multiwoz22Processor() | |||
| @type_assert(object, dict) | |||
| def __call__(self, data: Dict) -> Dict[str, Any]: | |||
| """process the raw input data | |||
| Args: | |||
| data (str): a sentence | |||
| Example: | |||
| 'you are so handsome.' | |||
| Returns: | |||
| Dict[str, Any]: the preprocessed data | |||
| """ | |||
| import torch | |||
| from torch.utils.data import (DataLoader, RandomSampler, | |||
| SequentialSampler) | |||
| utter = data['utter'] | |||
| history_states = data['history_states'] | |||
| example = self.processor.create_example( | |||
| inputs=utter, | |||
| history_states=history_states, | |||
| set_type='test', | |||
| slot_list=self.config.dst_slot_list, | |||
| label_maps={}, | |||
| append_history=True, | |||
| use_history_labels=True, | |||
| swap_utterances=True, | |||
| label_value_repetitions=True, | |||
| delexicalize_sys_utts=True, | |||
| unk_token='[UNK]', | |||
| analyze=False) | |||
| features = convert_examples_to_features( | |||
| examples=[example], | |||
| slot_list=self.config.dst_slot_list, | |||
| class_types=self.config.dst_class_types, | |||
| model_type=self.config.model_type, | |||
| tokenizer=self.tokenizer, | |||
| max_seq_length=180, # args.max_seq_length | |||
| slot_value_dropout=(0.0)) | |||
| all_input_ids = torch.tensor([f.input_ids for f in features], | |||
| dtype=torch.long) | |||
| all_input_mask = torch.tensor([f.input_mask for f in features], | |||
| dtype=torch.long) | |||
| all_segment_ids = torch.tensor([f.segment_ids for f in features], | |||
| dtype=torch.long) | |||
| all_example_index = torch.arange( | |||
| all_input_ids.size(0), dtype=torch.long) | |||
| f_start_pos = [f.start_pos for f in features] | |||
| f_end_pos = [f.end_pos for f in features] | |||
| f_inform_slot_ids = [f.inform_slot for f in features] | |||
| f_refer_ids = [f.refer_id for f in features] | |||
| f_diag_state = [f.diag_state for f in features] | |||
| f_class_label_ids = [f.class_label_id for f in features] | |||
| all_start_positions = {} | |||
| all_end_positions = {} | |||
| all_inform_slot_ids = {} | |||
| all_refer_ids = {} | |||
| all_diag_state = {} | |||
| all_class_label_ids = {} | |||
| for s in self.config.dst_slot_list: | |||
| all_start_positions[s] = torch.tensor([f[s] for f in f_start_pos], | |||
| dtype=torch.long) | |||
| all_end_positions[s] = torch.tensor([f[s] for f in f_end_pos], | |||
| dtype=torch.long) | |||
| all_inform_slot_ids[s] = torch.tensor( | |||
| [f[s] for f in f_inform_slot_ids], dtype=torch.long) | |||
| all_refer_ids[s] = torch.tensor([f[s] for f in f_refer_ids], | |||
| dtype=torch.long) | |||
| all_diag_state[s] = torch.tensor([f[s] for f in f_diag_state], | |||
| dtype=torch.long) | |||
| all_class_label_ids[s] = torch.tensor( | |||
| [f[s] for f in f_class_label_ids], dtype=torch.long) | |||
| dataset = [ | |||
| all_input_ids, all_input_mask, all_segment_ids, | |||
| all_start_positions, all_end_positions, all_inform_slot_ids, | |||
| all_refer_ids, all_diag_state, all_class_label_ids, | |||
| all_example_index | |||
| ] | |||
| with torch.no_grad(): | |||
| diag_state = { | |||
| slot: | |||
| torch.tensor([0 for _ in range(self.config.eval_batch_size) | |||
| ]).to(self.config.device) | |||
| for slot in self.config.dst_slot_list | |||
| } | |||
| if len(history_states) > 2: | |||
| ds = history_states[-2] | |||
| else: | |||
| ds = {slot: 'none' for slot in self.config.dst_slot_list} | |||
| return { | |||
| 'batch': dataset, | |||
| 'features': features, | |||
| 'diag_state': diag_state, | |||
| 'ds': ds | |||
| } | |||
| @@ -0,0 +1,59 @@ | |||
| # | |||
| # Copyright 2020 Heinrich Heine University Duesseldorf | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| from torch.utils.data import Dataset | |||
| class TensorListDataset(Dataset): | |||
| r"""Dataset wrapping tensors, tensor dicts and tensor lists. | |||
| Arguments: | |||
| *data (Tensor or dict or list of Tensors): tensors that have the same size | |||
| of the first dimension. | |||
| """ | |||
| def __init__(self, *data): | |||
| if isinstance(data[0], dict): | |||
| size = list(data[0].values())[0].size(0) | |||
| elif isinstance(data[0], list): | |||
| size = data[0][0].size(0) | |||
| else: | |||
| size = data[0].size(0) | |||
| for element in data: | |||
| if isinstance(element, dict): | |||
| assert all( | |||
| size == tensor.size(0) | |||
| for name, tensor in element.items()) # dict of tensors | |||
| elif isinstance(element, list): | |||
| assert all(size == tensor.size(0) | |||
| for tensor in element) # list of tensors | |||
| else: | |||
| assert size == element.size(0) # tensor | |||
| self.size = size | |||
| self.data = data | |||
| def __getitem__(self, index): | |||
| result = [] | |||
| for element in self.data: | |||
| if isinstance(element, dict): | |||
| result.append({k: v[index] for k, v in element.items()}) | |||
| elif isinstance(element, list): | |||
| result.append(v[index] for v in element) | |||
| else: | |||
| result.append(element[index]) | |||
| return tuple(result) | |||
| def __len__(self): | |||
| return self.size | |||
| @@ -48,6 +48,7 @@ class Tasks(object): | |||
| text_generation = 'text-generation' | |||
| dialog_modeling = 'dialog-modeling' | |||
| dialog_intent_prediction = 'dialog-intent-prediction' | |||
| dialog_state_tracking = 'dialog-state-tracking' | |||
| table_question_answering = 'table-question-answering' | |||
| feature_extraction = 'feature-extraction' | |||
| fill_mask = 'fill-mask' | |||
| @@ -0,0 +1,10 @@ | |||
| def batch_to_device(batch, device): | |||
| batch_on_device = [] | |||
| for element in batch: | |||
| if isinstance(element, dict): | |||
| batch_on_device.append( | |||
| {k: v.to(device) | |||
| for k, v in element.items()}) | |||
| else: | |||
| batch_on_device.append(element.to(device)) | |||
| return tuple(batch_on_device) | |||
| @@ -1,3 +1,3 @@ | |||
| https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz | |||
| http://ait-public.oss-cn-hangzhou-zmf.aliyuncs.com/jizhu/en_core_web_sm-2.3.1.tar.gz | |||
| sofa==1.0.5 | |||
| spacy>=2.3.5 | |||
| @@ -18,7 +18,7 @@ class DialogIntentPredictionTest(unittest.TestCase): | |||
| ] | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run(self): | |||
| def test_run_by_direct_model_download(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path) | |||
| model = SpaceForDialogIntent( | |||
| @@ -56,6 +56,20 @@ class DialogIntentPredictionTest(unittest.TestCase): | |||
| for my_pipeline, item in list(zip(pipelines, self.test_case)): | |||
| print(my_pipeline(item)) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipelines = [ | |||
| pipeline(task=Tasks.dialog_intent_prediction, model=self.model_id) | |||
| ] | |||
| for my_pipeline, item in list(zip(pipelines, self.test_case)): | |||
| print(my_pipeline(item)) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipelines = [pipeline(task=Tasks.dialog_intent_prediction)] | |||
| for my_pipeline, item in list(zip(pipelines, self.test_case)): | |||
| print(my_pipeline(item)) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -1,5 +1,6 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from typing import List | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| @@ -89,8 +90,22 @@ class DialogModelingTest(unittest.TestCase): | |||
| } | |||
| } | |||
| def generate_and_print_dialog_response( | |||
| self, pipelines: List[DialogModelingPipeline]): | |||
| result = {} | |||
| for step, item in enumerate(self.test_case['sng0073']['log']): | |||
| user = item['user'] | |||
| print('user: {}'.format(user)) | |||
| result = pipelines[step % 2]({ | |||
| 'user_input': user, | |||
| 'history': result | |||
| }) | |||
| print('response : {}'.format(result['response'])) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run(self): | |||
| def test_run_by_direct_model_download(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| @@ -106,17 +121,7 @@ class DialogModelingTest(unittest.TestCase): | |||
| model=model, | |||
| preprocessor=preprocessor) | |||
| ] | |||
| result = {} | |||
| for step, item in enumerate(self.test_case['sng0073']['log']): | |||
| user = item['user'] | |||
| print('user: {}'.format(user)) | |||
| result = pipelines[step % 2]({ | |||
| 'user_input': user, | |||
| 'history': result | |||
| }) | |||
| print('response : {}'.format(result['response'])) | |||
| self.generate_and_print_dialog_response(pipelines) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| @@ -131,16 +136,23 @@ class DialogModelingTest(unittest.TestCase): | |||
| preprocessor=preprocessor) | |||
| ] | |||
| result = {} | |||
| for step, item in enumerate(self.test_case['sng0073']['log']): | |||
| user = item['user'] | |||
| print('user: {}'.format(user)) | |||
| self.generate_and_print_dialog_response(pipelines) | |||
| result = pipelines[step % 2]({ | |||
| 'user_input': user, | |||
| 'history': result | |||
| }) | |||
| print('response : {}'.format(result['response'])) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipelines = [ | |||
| pipeline(task=Tasks.dialog_modeling, model=self.model_id), | |||
| pipeline(task=Tasks.dialog_modeling, model=self.model_id) | |||
| ] | |||
| self.generate_and_print_dialog_response(pipelines) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipelines = [ | |||
| pipeline(task=Tasks.dialog_modeling), | |||
| pipeline(task=Tasks.dialog_modeling) | |||
| ] | |||
| self.generate_and_print_dialog_response(pipelines) | |||
| if __name__ == '__main__': | |||
| @@ -0,0 +1,143 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from typing import List | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model, SpaceForDialogStateTracking | |||
| from modelscope.pipelines import DialogStateTrackingPipeline, pipeline | |||
| from modelscope.preprocessors import DialogStateTrackingPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class DialogStateTrackingTest(unittest.TestCase): | |||
| model_id = 'damo/nlp_space_dialog-state-tracking' | |||
| test_case = [{ | |||
| 'User-1': | |||
| 'Hi, I\'m looking for a train that is going to cambridge and arriving there by 20:45, ' | |||
| 'is there anything like that?' | |||
| }, { | |||
| 'System-1': | |||
| 'There are over 1,000 trains like that. Where will you be departing from?', | |||
| 'Dialog_Act-1': { | |||
| 'Train-Inform': [['Choice', 'over 1'], ['Choice', '000']], | |||
| 'Train-Request': [['Depart', '?']] | |||
| }, | |||
| 'User-2': 'I am departing from birmingham new street.' | |||
| }, { | |||
| 'System-2': 'Can you confirm your desired travel day?', | |||
| 'Dialog_Act-2': { | |||
| 'Train-Request': [['Day', '?']] | |||
| }, | |||
| 'User-3': 'I would like to leave on wednesday' | |||
| }, { | |||
| 'System-3': | |||
| 'I show a train leaving birmingham new street at 17:40 and arriving at 20:23 on Wednesday. ' | |||
| 'Will this work for you?', | |||
| 'Dialog_Act-3': { | |||
| 'Train-Inform': [['Arrive', '20:23'], ['Leave', '17:40'], | |||
| ['Day', 'Wednesday'], | |||
| ['Depart', 'birmingham new street']] | |||
| }, | |||
| 'User-4': | |||
| 'That will, yes. Please make a booking for 5 people please.', | |||
| }, { | |||
| 'System-4': | |||
| 'I\'ve booked your train tickets, and your reference number is A9NHSO9Y.', | |||
| 'Dialog_Act-4': { | |||
| 'Train-OfferBooked': [['Ref', 'A9NHSO9Y']] | |||
| }, | |||
| 'User-5': | |||
| 'Thanks so much. I would also need a place to say. ' | |||
| 'I am looking for something with 4 stars and has free wifi.' | |||
| }, { | |||
| 'System-5': | |||
| 'How about the cambridge belfry? ' | |||
| 'It has all the attributes you requested and a great name! ' | |||
| 'Maybe even a real belfry?', | |||
| 'Dialog_Act-5': { | |||
| 'Hotel-Recommend': [['Name', 'the cambridge belfry']] | |||
| }, | |||
| 'User-6': | |||
| 'That sounds great, could you make a booking for me please?', | |||
| }, { | |||
| 'System-6': | |||
| 'What day would you like your booking for?', | |||
| 'Dialog_Act-6': { | |||
| 'Booking-Request': [['Day', '?']] | |||
| }, | |||
| 'User-7': | |||
| 'Please book it for Wednesday for 5 people and 5 nights, please.', | |||
| }, { | |||
| 'System-7': 'Booking was successful. Reference number is : 5NAWGJDC.', | |||
| 'Dialog_Act-7': { | |||
| 'Booking-Book': [['Ref', '5NAWGJDC']] | |||
| }, | |||
| 'User-8': 'Thank you, goodbye', | |||
| }] | |||
| def tracking_and_print_dialog_states( | |||
| self, pipelines: List[DialogStateTrackingPipeline]): | |||
| import json | |||
| pipelines_len = len(pipelines) | |||
| history_states = [{}] | |||
| utter = {} | |||
| for step, item in enumerate(self.test_case): | |||
| utter.update(item) | |||
| result = pipelines[step % pipelines_len]({ | |||
| 'utter': | |||
| utter, | |||
| 'history_states': | |||
| history_states | |||
| }) | |||
| print(json.dumps(result)) | |||
| history_states.extend([result['dialog_states'], {}]) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_by_direct_model_download(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| model = SpaceForDialogStateTracking(cache_path) | |||
| preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) | |||
| pipelines = [ | |||
| DialogStateTrackingPipeline( | |||
| model=model, preprocessor=preprocessor), | |||
| pipeline( | |||
| task=Tasks.dialog_state_tracking, | |||
| model=model, | |||
| preprocessor=preprocessor) | |||
| ] | |||
| self.tracking_and_print_dialog_states(pipelines) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| preprocessor = DialogStateTrackingPreprocessor( | |||
| model_dir=model.model_dir) | |||
| pipelines = [ | |||
| DialogStateTrackingPipeline( | |||
| model=model, preprocessor=preprocessor), | |||
| pipeline( | |||
| task=Tasks.dialog_state_tracking, | |||
| model=model, | |||
| preprocessor=preprocessor) | |||
| ] | |||
| self.tracking_and_print_dialog_states(pipelines) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipelines = [ | |||
| pipeline(task=Tasks.dialog_state_tracking, model=self.model_id) | |||
| ] | |||
| self.tracking_and_print_dialog_states(pipelines) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipelines = [pipeline(task=Tasks.dialog_state_tracking)] | |||
| self.tracking_and_print_dialog_states(pipelines) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||