
[to #42322933] add space dialog-state tracking pipeline

code review: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9227018
master
Yingda Chen, 3 years ago
commit f8c669e144
24 changed files with 2194 additions and 66 deletions
  1. modelscope/metainfo.py  +2 -0
  2. modelscope/models/__init__.py  +8 -6
  3. modelscope/models/nlp/__init__.py  +1 -0
  4. modelscope/models/nlp/space/dialog_intent_prediction_model.py  +6 -5
  5. modelscope/models/nlp/space/dialog_modeling_model.py  +7 -5
  6. modelscope/models/nlp/space/dialog_state_tracking_model.py  +103 -0
  7. modelscope/pipelines/builder.py  +2 -0
  8. modelscope/pipelines/nlp/__init__.py  +1 -0
  9. modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py  +14 -8
  10. modelscope/pipelines/nlp/dialog_modeling_pipeline.py  +14 -9
  11. modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py  +159 -0
  12. modelscope/pipelines/nlp/sentiment_classification_pipeline.py  +0 -1
  13. modelscope/pipelines/nlp/zero_shot_classification_pipeline.py  +0 -9
  14. modelscope/pipelines/outputs.py  +40 -0
  15. modelscope/preprocessors/__init__.py  +1 -0
  16. modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py  +133 -0
  17. modelscope/preprocessors/space/dst_processors.py  +1441 -0
  18. modelscope/preprocessors/space/tensorlistdataset.py  +59 -0
  19. modelscope/utils/constant.py  +1 -0
  20. modelscope/utils/nlp/space/utils_dst.py  +10 -0
  21. requirements/nlp.txt  +1 -1
  22. tests/pipelines/test_dialog_intent_prediction.py  +15 -1
  23. tests/pipelines/test_dialog_modeling.py  +33 -21
  24. tests/pipelines/test_dialog_state_tracking.py  +143 -0
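At a glance, this commit registers a new dialog-state-tracking task, model, preprocessor, pipeline and output key. A minimal end-to-end sketch, mirroring tests/pipelines/test_dialog_state_tracking.py added below (the model id, input keys and output key are taken from that test):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

dst = pipeline(
    task=Tasks.dialog_state_tracking,
    model='damo/nlp_space_dialog-state-tracking')

# first turn: cumulative utterances plus an initially empty state history
result = dst({
    'utter': {
        'User-1': "Hi, I'm looking for a train that is going to cambridge "
                  "and arriving there by 20:45, is there anything like that?"
    },
    'history_states': [{}]
})
print(result['dialog_states'])  # dict mapping each tracked slot to its current value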

modelscope/metainfo.py  +2 -0

@@ -62,6 +62,7 @@ class Pipelines(object):
nli = 'nli'
dialog_intent_prediction = 'dialog-intent-prediction'
dialog_modeling = 'dialog-modeling'
dialog_state_tracking = 'dialog-state-tracking'
zero_shot_classification = 'zero-shot-classification'

# audio tasks
@@ -112,6 +113,7 @@ class Preprocessors(object):
sen_cls_tokenizer = 'sen-cls-tokenizer'
dialog_intent_preprocessor = 'dialog-intent-preprocessor'
dialog_modeling_preprocessor = 'dialog-modeling-preprocessor'
dialog_state_tracking_preprocessor = 'dialog-state-tracking-preprocessor'
sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'
zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer'



modelscope/models/__init__.py  +8 -6

@@ -15,12 +15,14 @@ except ModuleNotFoundError as e:
try:
from .audio.kws import GenericKeyWordSpotting
from .multi_modal import OfaForImageCaptioning
from .nlp import (
BertForMaskedLM, BertForSequenceClassification, CsanmtForTranslation,
SbertForNLI, SbertForSentenceSimilarity,
SbertForSentimentClassification, SbertForTokenClassification,
SbertForZeroShotClassification, SpaceForDialogIntent,
SpaceForDialogModeling, StructBertForMaskedLM, VecoForMaskedLM)
from .nlp import (BertForMaskedLM, BertForSequenceClassification,
CsanmtForTranslation, SbertForNLI,
SbertForSentenceSimilarity,
SbertForSentimentClassification,
SbertForTokenClassification,
SbertForZeroShotClassification, SpaceForDialogIntent,
SpaceForDialogModeling, SpaceForDialogStateTracking,
StructBertForMaskedLM, VecoForMaskedLM)
from .audio.ans.frcrn import FRCRNModel
except ModuleNotFoundError as e:
if str(e) == "No module named 'pytorch'":


modelscope/models/nlp/__init__.py  +1 -0

@@ -9,3 +9,4 @@ from .sbert_for_token_classification import * # noqa F403
from .sbert_for_zero_shot_classification import * # noqa F403
from .space.dialog_intent_prediction_model import * # noqa F403
from .space.dialog_modeling_model import * # noqa F403
from .space.dialog_state_tracking_model import * # noqa F403

modelscope/models/nlp/space/dialog_intent_prediction_model.py  +6 -5

@@ -63,15 +63,16 @@ class SpaceForDialogIntent(Model):
"""return the result by the model

Args:
input (Dict[str, Any]): the preprocessed data
input (Dict[str, Tensor]): the preprocessed data

Returns:
Dict[str, np.ndarray]: results
Dict[str, Tensor]: results
Example:
{
'predictions': array([1]), # lable 0-negative 1-positive
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
'pred': array([2.62349960e-03 4.12110658e-03 4.12748595e-05 3.77560973e-05
1.08599677e-04 1.72710388e-05 2.95618793e-05 1.93638436e-04
6.45841064e-05 1.15997791e-04 5.11605394e-05 9.87020373e-01
2.66957268e-05 4.72324500e-05 9.74208378e-05], dtype=float32)
}
"""
import numpy as np


modelscope/models/nlp/space/dialog_modeling_model.py  +7 -5

@@ -62,15 +62,17 @@ class SpaceForDialogModeling(Model):
"""return the result by the model

Args:
input (Dict[str, Any]): the preprocessed data
input (Dict[str, Tensor]): the preprocessed data

Returns:
Dict[str, np.ndarray]: results
Dict[str, Tensor]: results
Example:
{
'predictions': array([1]), # lable 0-negative 1-positive
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
'labels': array([1,192,321,12]), # label
'resp': array([293,1023,123,1123]), # vocab label for response
'bspn': array([123,321,2,24,1 ]),
'aspn': array([47,8345,32,29,1983]),
'db': array([19, 24, 20]),
}
"""



modelscope/models/nlp/space/dialog_state_tracking_model.py  +103 -0

@@ -0,0 +1,103 @@
import os
from typing import Any, Dict

from modelscope.utils.constant import Tasks
from ....metainfo import Models
from ....utils.nlp.space.utils_dst import batch_to_device
from ...base import Model, Tensor
from ...builder import MODELS

__all__ = ['SpaceForDialogStateTracking']


@MODELS.register_module(Tasks.dialog_state_tracking, module_name=Models.space)
class SpaceForDialogStateTracking(Model):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the dialog state tracking model from the `model_dir` path.

Args:
model_dir (str): the model path.
"""

super().__init__(model_dir, *args, **kwargs)

from sofa.models.space import SpaceForDST, SpaceConfig
self.model_dir = model_dir

self.config = SpaceConfig.from_pretrained(self.model_dir)
self.model = SpaceForDST.from_pretrained(self.model_dir)
self.model.to(self.config.device)

def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""return the result by the model

Args:
input (Dict[str, Tensor]): the preprocessed data

Returns:
Dict[str, Tensor]: results
Example:
{
'inputs': dict(input_ids, input_masks,start_pos), # tracking states
'outputs': dict(slots_logits),
'unique_ids': str(test-example.json-0), # default value
'input_ids_unmasked': array([101, 7632, 1010,0,0,0])
'values': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]),
'inform': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}]),
'prefix': str('final'), #default value
'ds': array([{'taxi-leaveAt': 'none', 'taxi-destination': 'none'}])
}
"""
import numpy as np
import torch

self.model.eval()
batch = input['batch']
batch = batch_to_device(batch, self.config.device)

features = input['features']
diag_state = input['diag_state']
turn_itrs = [features[i.item()].guid.split('-')[2] for i in batch[9]]
reset_diag_state = np.where(np.array(turn_itrs) == '0')[0]
for slot in self.config.dst_slot_list:
for i in reset_diag_state:
diag_state[slot][i] = 0

with torch.no_grad():
inputs = {
'input_ids': batch[0],
'input_mask': batch[1],
'segment_ids': batch[2],
'start_pos': batch[3],
'end_pos': batch[4],
'inform_slot_id': batch[5],
'refer_id': batch[6],
'diag_state': diag_state,
'class_label_id': batch[8]
}
unique_ids = [features[i.item()].guid for i in batch[9]]
values = [features[i.item()].values for i in batch[9]]
input_ids_unmasked = [
features[i.item()].input_ids_unmasked for i in batch[9]
]
inform = [features[i.item()].inform for i in batch[9]]
outputs = self.model(**inputs)

# Update dialog state for next turn.
for slot in self.config.dst_slot_list:
updates = outputs[2][slot].max(1)[1]
for i, u in enumerate(updates):
if u != 0:
diag_state[slot][i] = u

return {
'inputs': inputs,
'outputs': outputs,
'unique_ids': unique_ids,
'input_ids_unmasked': input_ids_unmasked,
'values': values,
'inform': inform,
'prefix': 'final',
'ds': input['ds']
}
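The forward pass consumes exactly the dict produced by DialogStateTrackingPreprocessor (added further down). A minimal sketch of driving the model and preprocessor directly, following the direct-download test at the end of this change set (the single-turn utterance is only illustrative):

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import SpaceForDialogStateTracking
from modelscope.preprocessors import DialogStateTrackingPreprocessor

cache_path = snapshot_download('damo/nlp_space_dialog-state-tracking')
model = SpaceForDialogStateTracking(cache_path)
preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path)

# illustrative single-turn input; real runs accumulate turns as in the tests
inputs = preprocessor({
    'utter': {'User-1': 'I am looking for a cheap hotel with free wifi.'},
    'history_states': [{}]
})
outputs = model.forward(inputs)  # dict with 'inputs', 'outputs', 'ds', ... as documented above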

modelscope/pipelines/builder.py  +2 -0

@@ -41,6 +41,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/nlp_space_dialog-intent-prediction'),
Tasks.dialog_modeling: (Pipelines.dialog_modeling,
'damo/nlp_space_dialog-modeling'),
Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking,
'damo/nlp_space_dialog-state-tracking'),
Tasks.image_captioning: (Pipelines.image_caption,
'damo/ofa_image-caption_coco_large_en'),
Tasks.image_generation:
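With this DEFAULT_MODEL_FOR_PIPELINE entry, the task name alone is enough to build the pipeline; a one-line sketch of the fallback that test_run_with_default_model exercises below:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# no explicit model: the builder falls back to 'damo/nlp_space_dialog-state-tracking'
dst = pipeline(task=Tasks.dialog_state_tracking)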


modelscope/pipelines/nlp/__init__.py  +1 -0

@@ -1,6 +1,7 @@
try:
from .dialog_intent_prediction_pipeline import * # noqa F403
from .dialog_modeling_pipeline import * # noqa F403
from .dialog_state_tracking_pipeline import * # noqa F403
from .fill_mask_pipeline import * # noqa F403
from .nli_pipeline import * # noqa F403
from .sentence_similarity_pipeline import * # noqa F403


modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py  +14 -8

@@ -1,8 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict
from typing import Any, Dict, Union

from ...metainfo import Pipelines
from ...models import Model
from ...models.nlp import SpaceForDialogIntent
from ...preprocessors import DialogIntentPredictionPreprocessor
from ...utils.constant import Tasks
@@ -18,17 +19,22 @@ __all__ = ['DialogIntentPredictionPipeline']
module_name=Pipelines.dialog_intent_prediction)
class DialogIntentPredictionPipeline(Pipeline):

def __init__(self, model: SpaceForDialogIntent,
preprocessor: DialogIntentPredictionPreprocessor, **kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction
def __init__(self,
model: Union[SpaceForDialogIntent, str],
preprocessor: DialogIntentPredictionPreprocessor = None,
**kwargs):
"""use `model` and `preprocessor` to create a dialog intent prediction pipeline

Args:
model (SequenceClassificationModel): a model instance
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
model (SpaceForDialogIntent): a model instance
preprocessor (DialogIntentPredictionPreprocessor): a preprocessor instance
"""

super().__init__(model=model, preprocessor=preprocessor, **kwargs)
model = model if isinstance(
model, SpaceForDialogIntent) else Model.from_pretrained(model)
if preprocessor is None:
preprocessor = DialogIntentPredictionPreprocessor(model.model_dir)
self.model = model
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
self.categories = preprocessor.categories

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:


modelscope/pipelines/nlp/dialog_modeling_pipeline.py  +14 -9

@@ -1,8 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict, Optional
from typing import Any, Dict, Union

from ...metainfo import Pipelines
from ...models import Model
from ...models.nlp import SpaceForDialogModeling
from ...preprocessors import DialogModelingPreprocessor
from ...utils.constant import Tasks
@@ -17,17 +18,22 @@ __all__ = ['DialogModelingPipeline']
Tasks.dialog_modeling, module_name=Pipelines.dialog_modeling)
class DialogModelingPipeline(Pipeline):

def __init__(self, model: SpaceForDialogModeling,
preprocessor: DialogModelingPreprocessor, **kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction
def __init__(self,
model: Union[SpaceForDialogModeling, str],
preprocessor: DialogModelingPreprocessor = None,
**kwargs):
"""use `model` and `preprocessor` to create a dialog modeling pipeline for dialog response generation

Args:
model (SequenceClassificationModel): a model instance
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
model (SpaceForDialogModeling): a model instance
preprocessor (DialogModelingPreprocessor): a preprocessor instance
"""
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
model = model if isinstance(
model, SpaceForDialogModeling) else Model.from_pretrained(model)
self.model = model
if preprocessor is None:
preprocessor = DialogModelingPreprocessor(model.model_dir)
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
self.preprocessor = preprocessor

def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]:
@@ -43,7 +49,6 @@ class DialogModelingPipeline(Pipeline):
inputs['resp'])
assert len(sys_rsp) > 2
sys_rsp = sys_rsp[1:len(sys_rsp) - 1]

inputs[OutputKeys.RESPONSE] = sys_rsp

return inputs

modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py  +159 -0

@@ -0,0 +1,159 @@
from typing import Any, Dict, Union

from ...metainfo import Pipelines
from ...models import Model, SpaceForDialogStateTracking
from ...preprocessors import DialogStateTrackingPreprocessor
from ...utils.constant import Tasks
from ..base import Pipeline
from ..builder import PIPELINES
from ..outputs import OutputKeys

__all__ = ['DialogStateTrackingPipeline']


@PIPELINES.register_module(
Tasks.dialog_state_tracking, module_name=Pipelines.dialog_state_tracking)
class DialogStateTrackingPipeline(Pipeline):

def __init__(self,
model: Union[SpaceForDialogStateTracking, str],
preprocessor: DialogStateTrackingPreprocessor = None,
**kwargs):
"""use `model` and `preprocessor` to create a dialog state tracking pipeline that
tracks dialog states across multiple turns of an open-domain dialogue

Args:
model (SpaceForDialogStateTracking): a model instance
preprocessor (DialogStateTrackingPreprocessor): a preprocessor instance
"""

model = model if isinstance(
model,
SpaceForDialogStateTracking) else Model.from_pretrained(model)
self.model = model
if preprocessor is None:
preprocessor = DialogStateTrackingPreprocessor(model.model_dir)
super().__init__(model=model, preprocessor=preprocessor, **kwargs)

self.tokenizer = preprocessor.tokenizer
self.config = preprocessor.config

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""process the prediction results

Args:
inputs (Dict[str, Any]): the forward outputs of the dialog state tracking model

Returns:
Dict[str, str]: the prediction results
"""

_inputs = inputs['inputs']
_outputs = inputs['outputs']
unique_ids = inputs['unique_ids']
input_ids_unmasked = inputs['input_ids_unmasked']
values = inputs['values']
inform = inputs['inform']
prefix = inputs['prefix']
ds = inputs['ds']
ds = predict_and_format(self.config, self.tokenizer, _inputs,
_outputs[2], _outputs[3], _outputs[4],
_outputs[5], unique_ids, input_ids_unmasked,
values, inform, prefix, ds)

return {OutputKeys.DIALOG_STATES: ds}


def predict_and_format(config, tokenizer, features, per_slot_class_logits,
per_slot_start_logits, per_slot_end_logits,
per_slot_refer_logits, ids, input_ids_unmasked, values,
inform, prefix, ds):
import re

prediction_list = []
dialog_state = ds
for i in range(len(ids)):
if int(ids[i].split('-')[2]) == 0:
dialog_state = {slot: 'none' for slot in config.dst_slot_list}

prediction = {}
prediction_addendum = {}
for slot in config.dst_slot_list:
class_logits = per_slot_class_logits[slot][i]
start_logits = per_slot_start_logits[slot][i]
end_logits = per_slot_end_logits[slot][i]
refer_logits = per_slot_refer_logits[slot][i]

input_ids = features['input_ids'][i].tolist()
class_label_id = int(features['class_label_id'][slot][i])
start_pos = int(features['start_pos'][slot][i])
end_pos = int(features['end_pos'][slot][i])
refer_id = int(features['refer_id'][slot][i])

class_prediction = int(class_logits.argmax())
start_prediction = int(start_logits.argmax())
end_prediction = int(end_logits.argmax())
refer_prediction = int(refer_logits.argmax())

prediction['guid'] = ids[i].split('-')
prediction['class_prediction_%s' % slot] = class_prediction
prediction['class_label_id_%s' % slot] = class_label_id
prediction['start_prediction_%s' % slot] = start_prediction
prediction['start_pos_%s' % slot] = start_pos
prediction['end_prediction_%s' % slot] = end_prediction
prediction['end_pos_%s' % slot] = end_pos
prediction['refer_prediction_%s' % slot] = refer_prediction
prediction['refer_id_%s' % slot] = refer_id
prediction['input_ids_%s' % slot] = input_ids

if class_prediction == config.dst_class_types.index('dontcare'):
dialog_state[slot] = 'dontcare'
elif class_prediction == config.dst_class_types.index(
'copy_value'):
input_tokens = tokenizer.convert_ids_to_tokens(
input_ids_unmasked[i])
dialog_state[slot] = ' '.join(
input_tokens[start_prediction:end_prediction + 1])
dialog_state[slot] = re.sub('(^| )##', '', dialog_state[slot])
elif 'true' in config.dst_class_types and class_prediction == config.dst_class_types.index(
'true'):
dialog_state[slot] = 'true'
elif 'false' in config.dst_class_types and class_prediction == config.dst_class_types.index(
'false'):
dialog_state[slot] = 'false'
elif class_prediction == config.dst_class_types.index('inform'):
# dialog_state[slot] = '§§' + inform[i][slot]
if isinstance(inform[i][slot], str):
dialog_state[slot] = inform[i][slot]
elif isinstance(inform[i][slot], list):
dialog_state[slot] = inform[i][slot][0]
# Referral case is handled below

prediction_addendum['slot_prediction_%s'
% slot] = dialog_state[slot]
prediction_addendum['slot_groundtruth_%s' % slot] = values[i][slot]

# Referral case. All other slot values need to be seen first in order
# to be able to do this correctly.
for slot in config.dst_slot_list:
class_logits = per_slot_class_logits[slot][i]
refer_logits = per_slot_refer_logits[slot][i]

class_prediction = int(class_logits.argmax())
refer_prediction = int(refer_logits.argmax())

if 'refer' in config.dst_class_types and class_prediction == config.dst_class_types.index(
'refer'):
# Only slots that have been mentioned before can be referred to.
# One can think of a situation where one slot is referred to in the same utterance.
# This phenomenon is however currently not properly covered in the training data
# label generation process.
dialog_state[slot] = dialog_state[config.dst_slot_list[
refer_prediction - 1]]
prediction_addendum['slot_prediction_%s' %
slot] = dialog_state[slot] # Value update

prediction.update(prediction_addendum)
prediction_list.append(prediction)

return dialog_state

modelscope/pipelines/nlp/sentiment_classification_pipeline.py  +0 -1

@@ -74,5 +74,4 @@ class SentimentClassificationPipeline(Pipeline):
probs = probs[cls_ids].tolist()

cls_names = [self.model.id2label[cid] for cid in cls_ids]

return {OutputKeys.SCORES: probs, OutputKeys.LABELS: cls_names}

modelscope/pipelines/nlp/zero_shot_classification_pipeline.py  +0 -9

@@ -29,7 +29,6 @@ class ZeroShotClassificationPipeline(Pipeline):
preprocessor: ZeroShotClassificationPreprocessor = None,
**kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction

Args:
model (SbertForSentimentClassification): a model instance
preprocessor (SentimentClassificationPreprocessor): a preprocessor instance
@@ -39,10 +38,8 @@ class ZeroShotClassificationPipeline(Pipeline):
model = model if isinstance(
model,
SbertForZeroShotClassification) else Model.from_pretrained(model)

self.entailment_id = 0
self.contradiction_id = 2

if preprocessor is None:
preprocessor = ZeroShotClassificationPreprocessor(model.model_dir)
model.eval()
@@ -51,7 +48,6 @@ class ZeroShotClassificationPipeline(Pipeline):
def _sanitize_parameters(self, **kwargs):
preprocess_params = {}
postprocess_params = {}

if 'candidate_labels' in kwargs:
candidate_labels = kwargs.pop('candidate_labels')
preprocess_params['candidate_labels'] = candidate_labels
@@ -60,7 +56,6 @@ class ZeroShotClassificationPipeline(Pipeline):
raise ValueError('You must include at least one label.')
preprocess_params['hypothesis_template'] = kwargs.pop(
'hypothesis_template', '{}')

postprocess_params['multi_label'] = kwargs.pop('multi_label', False)
return preprocess_params, {}, postprocess_params

@@ -74,14 +69,11 @@ class ZeroShotClassificationPipeline(Pipeline):
candidate_labels,
multi_label=False) -> Dict[str, Any]:
"""process the prediction results

Args:
inputs (Dict[str, Any]): _description_

Returns:
Dict[str, Any]: the prediction results
"""

logits = inputs['logits']
if multi_label or len(candidate_labels) == 1:
logits = logits[..., [self.contradiction_id, self.entailment_id]]
@@ -89,7 +81,6 @@ class ZeroShotClassificationPipeline(Pipeline):
else:
logits = logits[..., self.entailment_id]
scores = softmax(logits, axis=-1)

reversed_index = list(reversed(scores.argsort()))
result = {
OutputKeys.LABELS: [candidate_labels[i] for i in reversed_index],


modelscope/pipelines/outputs.py  +40 -0

@@ -21,6 +21,7 @@ class OutputKeys(object):
TRANSLATION = 'translation'
RESPONSE = 'response'
PREDICTION = 'prediction'
DIALOG_STATES = 'dialog_states'
VIDEO_EMBEDDING = 'video_embedding'


@@ -158,6 +159,7 @@ TASK_OUTPUTS = {
# }
Tasks.nli: [OutputKeys.SCORES, OutputKeys.LABELS],

# dialog intent prediction result for single sample
# {'pred': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05,
# 1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04,
# 6.45841064e-05, 1.15997791e-04, 5.11605394e-05, 9.87020373e-01,
@@ -181,9 +183,47 @@ TASK_OUTPUTS = {
Tasks.dialog_intent_prediction:
[OutputKeys.PREDICTION, OutputKeys.LABEL_POS, OutputKeys.LABEL],

# dialog modeling prediction result for single sample
# sys : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!']
Tasks.dialog_modeling: [OutputKeys.RESPONSE],

# dialog state tracking result for single sample
# {
# "dialog_states": {
# "taxi-leaveAt": "none",
# "taxi-destination": "none",
# "taxi-departure": "none",
# "taxi-arriveBy": "none",
# "restaurant-book_people": "none",
# "restaurant-book_day": "none",
# "restaurant-book_time": "none",
# "restaurant-food": "none",
# "restaurant-pricerange": "none",
# "restaurant-name": "none",
# "restaurant-area": "none",
# "hotel-book_people": "none",
# "hotel-book_day": "none",
# "hotel-book_stay": "none",
# "hotel-name": "none",
# "hotel-area": "none",
# "hotel-parking": "none",
# "hotel-pricerange": "cheap",
# "hotel-stars": "none",
# "hotel-internet": "none",
# "hotel-type": "true",
# "attraction-type": "none",
# "attraction-name": "none",
# "attraction-area": "none",
# "train-book_people": "none",
# "train-leaveAt": "none",
# "train-destination": "none",
# "train-day": "none",
# "train-arriveBy": "none",
# "train-departure": "none"
# }
# }
Tasks.dialog_state_tracking: [OutputKeys.DIALOG_STATES],

# ============ audio tasks ===================

# audio processed for single file in PCM format
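The new TASK_OUTPUTS entry above declares which keys a dialog-state-tracking result must expose; a small sanity-check sketch, assuming `result` is the dict returned by the pipeline shown earlier:

from modelscope.pipelines.outputs import OutputKeys, TASK_OUTPUTS
from modelscope.utils.constant import Tasks

expected = TASK_OUTPUTS[Tasks.dialog_state_tracking]  # [OutputKeys.DIALOG_STATES]
assert all(key in result for key in expected)
states = result[OutputKeys.DIALOG_STATES]  # e.g. {'hotel-pricerange': 'cheap', ...}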


modelscope/preprocessors/__init__.py  +1 -0

@@ -13,6 +13,7 @@ try:
from .nlp import * # noqa F403
from .space.dialog_intent_prediction_preprocessor import * # noqa F403
from .space.dialog_modeling_preprocessor import * # noqa F403
from .space.dialog_state_tracking_preprocessor import * # noqa F403
except ModuleNotFoundError as e:
if str(e) == "No module named 'tensorflow'":
pass


modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py  +133 -0

@@ -0,0 +1,133 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Any, Dict

from modelscope.utils.constant import Fields
from modelscope.utils.type_assert import type_assert
from ...metainfo import Preprocessors
from ..base import Preprocessor
from ..builder import PREPROCESSORS
from .dst_processors import convert_examples_to_features, multiwoz22Processor

__all__ = ['DialogStateTrackingPreprocessor']


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.dialog_state_tracking_preprocessor)
class DialogStateTrackingPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
model_dir (str): model path
"""
super().__init__(*args, **kwargs)

from sofa.models.space import SpaceTokenizer, SpaceConfig
self.model_dir: str = model_dir
self.config = SpaceConfig.from_pretrained(self.model_dir)
self.tokenizer = SpaceTokenizer.from_pretrained(self.model_dir)
self.processor = multiwoz22Processor()

@type_assert(object, dict)
def __call__(self, data: Dict) -> Dict[str, Any]:
"""process the raw input data

Args:
data (Dict): a dict carrying the accumulated dialog turns under 'utter'
and the tracked states of previous turns under 'history_states'
Example:
{'utter': {'User-1': 'I am looking for a train to cambridge.'},
'history_states': [{}]}

Returns:
Dict[str, Any]: the preprocessed data
"""
import torch
from torch.utils.data import (DataLoader, RandomSampler,
SequentialSampler)

utter = data['utter']
history_states = data['history_states']
example = self.processor.create_example(
inputs=utter,
history_states=history_states,
set_type='test',
slot_list=self.config.dst_slot_list,
label_maps={},
append_history=True,
use_history_labels=True,
swap_utterances=True,
label_value_repetitions=True,
delexicalize_sys_utts=True,
unk_token='[UNK]',
analyze=False)

features = convert_examples_to_features(
examples=[example],
slot_list=self.config.dst_slot_list,
class_types=self.config.dst_class_types,
model_type=self.config.model_type,
tokenizer=self.tokenizer,
max_seq_length=180, # args.max_seq_length
slot_value_dropout=(0.0))

all_input_ids = torch.tensor([f.input_ids for f in features],
dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features],
dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in features],
dtype=torch.long)
all_example_index = torch.arange(
all_input_ids.size(0), dtype=torch.long)
f_start_pos = [f.start_pos for f in features]
f_end_pos = [f.end_pos for f in features]
f_inform_slot_ids = [f.inform_slot for f in features]
f_refer_ids = [f.refer_id for f in features]
f_diag_state = [f.diag_state for f in features]
f_class_label_ids = [f.class_label_id for f in features]
all_start_positions = {}
all_end_positions = {}
all_inform_slot_ids = {}
all_refer_ids = {}
all_diag_state = {}
all_class_label_ids = {}
for s in self.config.dst_slot_list:
all_start_positions[s] = torch.tensor([f[s] for f in f_start_pos],
dtype=torch.long)
all_end_positions[s] = torch.tensor([f[s] for f in f_end_pos],
dtype=torch.long)
all_inform_slot_ids[s] = torch.tensor(
[f[s] for f in f_inform_slot_ids], dtype=torch.long)
all_refer_ids[s] = torch.tensor([f[s] for f in f_refer_ids],
dtype=torch.long)
all_diag_state[s] = torch.tensor([f[s] for f in f_diag_state],
dtype=torch.long)
all_class_label_ids[s] = torch.tensor(
[f[s] for f in f_class_label_ids], dtype=torch.long)
dataset = [
all_input_ids, all_input_mask, all_segment_ids,
all_start_positions, all_end_positions, all_inform_slot_ids,
all_refer_ids, all_diag_state, all_class_label_ids,
all_example_index
]

with torch.no_grad():
diag_state = {
slot:
torch.tensor([0 for _ in range(self.config.eval_batch_size)
]).to(self.config.device)
for slot in self.config.dst_slot_list
}

if len(history_states) > 2:
ds = history_states[-2]
else:
ds = {slot: 'none' for slot in self.config.dst_slot_list}

return {
'batch': dataset,
'features': features,
'diag_state': diag_state,
'ds': ds
}
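The 'utter' dict is cumulative across turns and 'history_states' grows by two entries per turn. A condensed sketch of that loop, lifted from tracking_and_print_dialog_states in the new test below; `dst` is assumed to be a dialog-state-tracking pipeline and `turns` a list of per-turn dicts in the test-case format:

history_states = [{}]
utter = {}
for turn in turns:
    utter.update(turn)  # accumulate the User-N / System-N / Dialog_Act-N keys
    result = dst({'utter': utter, 'history_states': history_states})
    history_states.extend([result['dialog_states'], {}])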

modelscope/preprocessors/space/dst_processors.py  +1441 -0
(file diff suppressed because it is too large)


modelscope/preprocessors/space/tensorlistdataset.py  +59 -0

@@ -0,0 +1,59 @@
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torch.utils.data import Dataset


class TensorListDataset(Dataset):
r"""Dataset wrapping tensors, tensor dicts and tensor lists.

Arguments:
*data (Tensor or dict or list of Tensors): tensors that have the same size
in the first dimension.
"""

def __init__(self, *data):
if isinstance(data[0], dict):
size = list(data[0].values())[0].size(0)
elif isinstance(data[0], list):
size = data[0][0].size(0)
else:
size = data[0].size(0)
for element in data:
if isinstance(element, dict):
assert all(
size == tensor.size(0)
for name, tensor in element.items()) # dict of tensors
elif isinstance(element, list):
assert all(size == tensor.size(0)
for tensor in element) # list of tensors
else:
assert size == element.size(0) # tensor
self.size = size
self.data = data

def __getitem__(self, index):
result = []
for element in self.data:
if isinstance(element, dict):
result.append({k: v[index] for k, v in element.items()})
elif isinstance(element, list):
result.append([v[index] for v in element])
else:
result.append(element[index])
return tuple(result)

def __len__(self):
return self.size
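TensorListDataset behaves like torch's TensorDataset but also accepts dicts (and lists) of tensors per element; a small usage sketch with illustrative shapes and slot names:

import torch
from torch.utils.data import DataLoader
from modelscope.preprocessors.space.tensorlistdataset import TensorListDataset

input_ids = torch.zeros(4, 8, dtype=torch.long)  # 4 examples of length 8
per_slot = {
    'hotel-area': torch.zeros(4, dtype=torch.long),
    'hotel-name': torch.ones(4, dtype=torch.long),
}  # dict of per-slot labels, one entry per example
dataset = TensorListDataset(input_ids, per_slot)

ids, slot_labels = dataset[0]  # tensor of shape (8,) and a dict of 0-dim tensors
loader = DataLoader(dataset, batch_size=2)  # default collate handles the dict element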

modelscope/utils/constant.py  +1 -0

@@ -48,6 +48,7 @@ class Tasks(object):
text_generation = 'text-generation'
dialog_modeling = 'dialog-modeling'
dialog_intent_prediction = 'dialog-intent-prediction'
dialog_state_tracking = 'dialog-state-tracking'
table_question_answering = 'table-question-answering'
feature_extraction = 'feature-extraction'
fill_mask = 'fill-mask'


modelscope/utils/nlp/space/utils_dst.py  +10 -0

@@ -0,0 +1,10 @@
def batch_to_device(batch, device):
batch_on_device = []
for element in batch:
if isinstance(element, dict):
batch_on_device.append(
{k: v.to(device)
for k, v in element.items()})
else:
batch_on_device.append(element.to(device))
return tuple(batch_on_device)
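batch_to_device moves a heterogeneous batch (plain tensors plus per-slot dicts, as built by the preprocessor above) onto the target device; a minimal sketch with illustrative tensors:

import torch
from modelscope.utils.nlp.space.utils_dst import batch_to_device

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch = (
    torch.zeros(2, 8, dtype=torch.long),  # e.g. input_ids
    {'hotel-area': torch.zeros(2, dtype=torch.long)},  # per-slot dict
)
batch = batch_to_device(batch, device)  # returns a tuple with every tensor on `device`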

requirements/nlp.txt  +1 -1

@@ -1,3 +1,3 @@
https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz
http://ait-public.oss-cn-hangzhou-zmf.aliyuncs.com/jizhu/en_core_web_sm-2.3.1.tar.gz
sofa==1.0.5
spacy>=2.3.5

tests/pipelines/test_dialog_intent_prediction.py  +15 -1

@@ -18,7 +18,7 @@ class DialogIntentPredictionTest(unittest.TestCase):
]

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run(self):
def test_run_by_direct_model_download(self):
cache_path = snapshot_download(self.model_id)
preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path)
model = SpaceForDialogIntent(
@@ -56,6 +56,20 @@ class DialogIntentPredictionTest(unittest.TestCase):
for my_pipeline, item in list(zip(pipelines, self.test_case)):
print(my_pipeline(item))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
pipelines = [
pipeline(task=Tasks.dialog_intent_prediction, model=self.model_id)
]
for my_pipeline, item in list(zip(pipelines, self.test_case)):
print(my_pipeline(item))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipelines = [pipeline(task=Tasks.dialog_intent_prediction)]
for my_pipeline, item in list(zip(pipelines, self.test_case)):
print(my_pipeline(item))


if __name__ == '__main__':
unittest.main()

tests/pipelines/test_dialog_modeling.py  +33 -21

@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from typing import List

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
@@ -89,8 +90,22 @@ class DialogModelingTest(unittest.TestCase):
}
}

def generate_and_print_dialog_response(
self, pipelines: List[DialogModelingPipeline]):

result = {}
for step, item in enumerate(self.test_case['sng0073']['log']):
user = item['user']
print('user: {}'.format(user))

result = pipelines[step % 2]({
'user_input': user,
'history': result
})
print('response : {}'.format(result['response']))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run(self):
def test_run_by_direct_model_download(self):

cache_path = snapshot_download(self.model_id)

@@ -106,17 +121,7 @@ class DialogModelingTest(unittest.TestCase):
model=model,
preprocessor=preprocessor)
]

result = {}
for step, item in enumerate(self.test_case['sng0073']['log']):
user = item['user']
print('user: {}'.format(user))

result = pipelines[step % 2]({
'user_input': user,
'history': result
})
print('response : {}'.format(result['response']))
self.generate_and_print_dialog_response(pipelines)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
@@ -131,16 +136,23 @@ class DialogModelingTest(unittest.TestCase):
preprocessor=preprocessor)
]

result = {}
for step, item in enumerate(self.test_case['sng0073']['log']):
user = item['user']
print('user: {}'.format(user))
self.generate_and_print_dialog_response(pipelines)

result = pipelines[step % 2]({
'user_input': user,
'history': result
})
print('response : {}'.format(result['response']))
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
pipelines = [
pipeline(task=Tasks.dialog_modeling, model=self.model_id),
pipeline(task=Tasks.dialog_modeling, model=self.model_id)
]
self.generate_and_print_dialog_response(pipelines)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipelines = [
pipeline(task=Tasks.dialog_modeling),
pipeline(task=Tasks.dialog_modeling)
]
self.generate_and_print_dialog_response(pipelines)


if __name__ == '__main__':


tests/pipelines/test_dialog_state_tracking.py  +143 -0

@@ -0,0 +1,143 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from typing import List

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model, SpaceForDialogStateTracking
from modelscope.pipelines import DialogStateTrackingPipeline, pipeline
from modelscope.preprocessors import DialogStateTrackingPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class DialogStateTrackingTest(unittest.TestCase):
model_id = 'damo/nlp_space_dialog-state-tracking'
test_case = [{
'User-1':
'Hi, I\'m looking for a train that is going to cambridge and arriving there by 20:45, '
'is there anything like that?'
}, {
'System-1':
'There are over 1,000 trains like that. Where will you be departing from?',
'Dialog_Act-1': {
'Train-Inform': [['Choice', 'over 1'], ['Choice', '000']],
'Train-Request': [['Depart', '?']]
},
'User-2': 'I am departing from birmingham new street.'
}, {
'System-2': 'Can you confirm your desired travel day?',
'Dialog_Act-2': {
'Train-Request': [['Day', '?']]
},
'User-3': 'I would like to leave on wednesday'
}, {
'System-3':
'I show a train leaving birmingham new street at 17:40 and arriving at 20:23 on Wednesday. '
'Will this work for you?',
'Dialog_Act-3': {
'Train-Inform': [['Arrive', '20:23'], ['Leave', '17:40'],
['Day', 'Wednesday'],
['Depart', 'birmingham new street']]
},
'User-4':
'That will, yes. Please make a booking for 5 people please.',
}, {
'System-4':
'I\'ve booked your train tickets, and your reference number is A9NHSO9Y.',
'Dialog_Act-4': {
'Train-OfferBooked': [['Ref', 'A9NHSO9Y']]
},
'User-5':
'Thanks so much. I would also need a place to say. '
'I am looking for something with 4 stars and has free wifi.'
}, {
'System-5':
'How about the cambridge belfry? '
'It has all the attributes you requested and a great name! '
'Maybe even a real belfry?',
'Dialog_Act-5': {
'Hotel-Recommend': [['Name', 'the cambridge belfry']]
},
'User-6':
'That sounds great, could you make a booking for me please?',
}, {
'System-6':
'What day would you like your booking for?',
'Dialog_Act-6': {
'Booking-Request': [['Day', '?']]
},
'User-7':
'Please book it for Wednesday for 5 people and 5 nights, please.',
}, {
'System-7': 'Booking was successful. Reference number is : 5NAWGJDC.',
'Dialog_Act-7': {
'Booking-Book': [['Ref', '5NAWGJDC']]
},
'User-8': 'Thank you, goodbye',
}]

def tracking_and_print_dialog_states(
self, pipelines: List[DialogStateTrackingPipeline]):
import json
pipelines_len = len(pipelines)
history_states = [{}]
utter = {}
for step, item in enumerate(self.test_case):
utter.update(item)
result = pipelines[step % pipelines_len]({
'utter':
utter,
'history_states':
history_states
})
print(json.dumps(result))

history_states.extend([result['dialog_states'], {}])

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self):
cache_path = snapshot_download(self.model_id)

model = SpaceForDialogStateTracking(cache_path)
preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path)
pipelines = [
DialogStateTrackingPipeline(
model=model, preprocessor=preprocessor),
pipeline(
task=Tasks.dialog_state_tracking,
model=model,
preprocessor=preprocessor)
]
self.tracking_and_print_dialog_states(pipelines)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
preprocessor = DialogStateTrackingPreprocessor(
model_dir=model.model_dir)
pipelines = [
DialogStateTrackingPipeline(
model=model, preprocessor=preprocessor),
pipeline(
task=Tasks.dialog_state_tracking,
model=model,
preprocessor=preprocessor)
]

self.tracking_and_print_dialog_states(pipelines)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
pipelines = [
pipeline(task=Tasks.dialog_state_tracking, model=self.model_id)
]
self.tracking_and_print_dialog_states(pipelines)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipelines = [pipeline(task=Tasks.dialog_state_tracking)]
self.tracking_and_print_dialog_states(pipelines)


if __name__ == '__main__':
unittest.main()
