Browse Source

[to #42322933] modify space task name

master
ly119399 3 years ago
parent
commit
16139cefb6
13 changed files with 104 additions and 114 deletions
  1. +4
    -2
      modelscope/metainfo.py
  2. +1
    -1
      modelscope/models/nlp/space/space_for_dialog_intent_prediction.py
  3. +5
    -3
      modelscope/models/nlp/space/space_for_dialog_modeling.py
  4. +2
    -1
      modelscope/models/nlp/space/space_for_dialog_state_tracking.py
  5. +37
    -38
      modelscope/outputs.py
  6. +1
    -6
      modelscope/pipelines/builder.py
  7. +2
    -3
      modelscope/pipelines/nlp/__init__.py
  8. +7
    -7
      modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py
  9. +4
    -5
      modelscope/pipelines/nlp/dialog_modeling_pipeline.py
  10. +3
    -2
      modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py
  11. +8
    -11
      tests/pipelines/test_dialog_intent_prediction.py
  12. +19
    -24
      tests/pipelines/test_dialog_modeling.py
  13. +11
    -11
      tests/pipelines/test_dialog_state_tracking.py

+ 4
- 2
modelscope/metainfo.py View File

@@ -26,7 +26,9 @@ class Models(object):
structbert = 'structbert'
veco = 'veco'
translation = 'csanmt-translation'
space = 'space'
space_dst = 'space-dst'
space_intent = 'space-intent'
space_modeling = 'space-modeling'
tcrf = 'transformer-crf'
bart = 'bart'
gpt3 = 'gpt3'
@@ -116,7 +118,7 @@ class Pipelines(object):
csanmt_translation = 'csanmt-translation'
nli = 'nli'
dialog_intent_prediction = 'dialog-intent-prediction'
task_oriented_conversation = 'task-oriented-conversation'
dialog_modeling = 'dialog-modeling'
dialog_state_tracking = 'dialog-state-tracking'
zero_shot_classification = 'zero-shot-classification'
text_error_correction = 'text-error-correction'


+ 1
- 1
modelscope/models/nlp/space/space_for_dialog_intent_prediction.py View File

@@ -16,7 +16,7 @@ __all__ = ['SpaceForDialogIntent']


@MODELS.register_module(
Tasks.dialog_intent_prediction, module_name=Models.space)
Tasks.task_oriented_conversation, module_name=Models.space_intent)
class SpaceForDialogIntent(TorchModel):

def __init__(self, model_dir: str, *args, **kwargs):


+ 5
- 3
modelscope/models/nlp/space/space_for_dialog_modeling.py View File

@@ -16,7 +16,7 @@ __all__ = ['SpaceForDialogModeling']


@MODELS.register_module(
Tasks.task_oriented_conversation, module_name=Models.space)
Tasks.task_oriented_conversation, module_name=Models.space_modeling)
class SpaceForDialogModeling(TorchModel):

def __init__(self, model_dir: str, *args, **kwargs):
@@ -34,8 +34,10 @@ class SpaceForDialogModeling(TorchModel):
Config.from_file(
os.path.join(self.model_dir, ModelFile.CONFIGURATION)))

self.config.use_gpu = True if 'device' not in kwargs or kwargs[
'device'] == 'gpu' else False
import torch
self.config.use_gpu = True if (
'device' not in kwargs or kwargs['device']
== 'gpu') and torch.cuda.is_available() else False

self.text_field = kwargs.pop(
'text_field',


+ 2
- 1
modelscope/models/nlp/space/space_for_dialog_state_tracking.py View File

@@ -9,7 +9,8 @@ from modelscope.utils.constant import Tasks
__all__ = ['SpaceForDialogStateTracking']


@MODELS.register_module(Tasks.dialog_state_tracking, module_name=Models.space)
@MODELS.register_module(
Tasks.task_oriented_conversation, module_name=Models.space_dst)
class SpaceForDialogStateTracking(TorchModel):

def __init__(self, model_dir: str, *args, **kwargs):


+ 37
- 38
modelscope/outputs.py View File

@@ -320,7 +320,7 @@ TASK_OUTPUTS = {
Tasks.fill_mask: [OutputKeys.TEXT],

# (Deprecated) dialog intent prediction result for single sample
# {'pred': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05,
# {'output': {'prediction': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05,
# 1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04,
# 6.45841064e-05, 1.15997791e-04, 5.11605394e-05, 9.87020373e-01,
# 2.66957268e-05, 4.72324500e-05, 9.74208378e-05, 4.18022355e-05,
@@ -339,50 +339,49 @@ TASK_OUTPUTS = {
# 3.61441926e-05, 3.38475402e-05, 3.44323053e-05, 5.70138109e-05,
# 4.31488479e-05, 4.94503947e-05, 4.30105974e-05, 1.00963116e-04,
# 2.82062047e-05, 1.15582036e-04, 4.48261271e-05, 3.99339879e-05,
# 7.27692823e-05], dtype=float32), 'label_pos': array([11]), 'label': 'lost_or_stolen_card'}
Tasks.dialog_intent_prediction:
[OutputKeys.PREDICTION, OutputKeys.LABEL_POS, OutputKeys.LABEL],
# 7.27692823e-05], dtype=float32), 'label_pos': array([11]), 'label': 'lost_or_stolen_card'}}

# (Deprecated) dialog modeling prediction result for single sample
# sys : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!']
Tasks.task_oriented_conversation: [OutputKeys.RESPONSE],
# {'output' : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!']}

# (Deprecated) dialog state tracking result for single sample
# {
# "dialog_states": {
# "taxi-leaveAt": "none",
# "taxi-destination": "none",
# "taxi-departure": "none",
# "taxi-arriveBy": "none",
# "restaurant-book_people": "none",
# "restaurant-book_day": "none",
# "restaurant-book_time": "none",
# "restaurant-food": "none",
# "restaurant-pricerange": "none",
# "restaurant-name": "none",
# "restaurant-area": "none",
# "hotel-book_people": "none",
# "hotel-book_day": "none",
# "hotel-book_stay": "none",
# "hotel-name": "none",
# "hotel-area": "none",
# "hotel-parking": "none",
# "hotel-pricerange": "cheap",
# "hotel-stars": "none",
# "hotel-internet": "none",
# "hotel-type": "true",
# "attraction-type": "none",
# "attraction-name": "none",
# "attraction-area": "none",
# "train-book_people": "none",
# "train-leaveAt": "none",
# "train-destination": "none",
# "train-day": "none",
# "train-arriveBy": "none",
# "train-departure": "none"
# "output":{
# "dialog_states": {
# "taxi-leaveAt": "none",
# "taxi-destination": "none",
# "taxi-departure": "none",
# "taxi-arriveBy": "none",
# "restaurant-book_people": "none",
# "restaurant-book_day": "none",
# "restaurant-book_time": "none",
# "restaurant-food": "none",
# "restaurant-pricerange": "none",
# "restaurant-name": "none",
# "restaurant-area": "none",
# "hotel-book_people": "none",
# "hotel-book_day": "none",
# "hotel-book_stay": "none",
# "hotel-name": "none",
# "hotel-area": "none",
# "hotel-parking": "none",
# "hotel-pricerange": "cheap",
# "hotel-stars": "none",
# "hotel-internet": "none",
# "hotel-type": "true",
# "attraction-type": "none",
# "attraction-name": "none",
# "attraction-area": "none",
# "train-book_people": "none",
# "train-leaveAt": "none",
# "train-destination": "none",
# "train-day": "none",
# "train-arriveBy": "none",
# "train-departure": "none"
# }
# }
# }
Tasks.dialog_state_tracking: [OutputKeys.DIALOG_STATES],
Tasks.task_oriented_conversation: [OutputKeys.OUTPUT],

# ============ audio tasks ===================
# asr result for single sample


+ 1
- 6
modelscope/pipelines/builder.py View File

@@ -48,13 +48,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.zero_shot_classification:
(Pipelines.zero_shot_classification,
'damo/nlp_structbert_zero-shot-classification_chinese-base'),
Tasks.dialog_intent_prediction:
(Pipelines.dialog_intent_prediction,
'damo/nlp_space_dialog-intent-prediction'),
Tasks.task_oriented_conversation: (Pipelines.task_oriented_conversation,
Tasks.task_oriented_conversation: (Pipelines.dialog_modeling,
'damo/nlp_space_dialog-modeling'),
Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking,
'damo/nlp_space_dialog-state-tracking'),
Tasks.text_error_correction:
(Pipelines.text_error_correction,
'damo/nlp_bart_text-error-correction_chinese'),


+ 2
- 3
modelscope/pipelines/nlp/__init__.py View File

@@ -5,7 +5,7 @@ from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .dialog_intent_prediction_pipeline import DialogIntentPredictionPipeline
from .task_oriented_conversation_pipeline import TaskOrientedConversationPipeline
from .dialog_modeling_pipeline import DialogModelingPipeline
from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline
from .fill_mask_pipeline import FillMaskPipeline
from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline
@@ -24,8 +24,7 @@ else:
_import_structure = {
'dialog_intent_prediction_pipeline':
['DialogIntentPredictionPipeline'],
'task_oriented_conversation_pipeline':
['TaskOrientedConversationPipeline'],
'dialog_modeling_pipeline': ['DialogModelingPipeline'],
'dialog_state_tracking_pipeline': ['DialogStateTrackingPipeline'],
'fill_mask_pipeline': ['FillMaskPipeline'],
'single_sentence_classification_pipeline':


+ 7
- 7
modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py View File

@@ -15,7 +15,7 @@ __all__ = ['DialogIntentPredictionPipeline']


@PIPELINES.register_module(
Tasks.dialog_intent_prediction,
Tasks.task_oriented_conversation,
module_name=Pipelines.dialog_intent_prediction)
class DialogIntentPredictionPipeline(Pipeline):

@@ -51,10 +51,10 @@ class DialogIntentPredictionPipeline(Pipeline):
pred = inputs['pred']
pos = np.where(pred == np.max(pred))

result = {
OutputKeys.PREDICTION: pred,
OutputKeys.LABEL_POS: pos[0],
OutputKeys.LABEL: self.categories[pos[0][0]]
return {
OutputKeys.OUTPUT: {
OutputKeys.PREDICTION: pred,
OutputKeys.LABEL_POS: pos[0],
OutputKeys.LABEL: self.categories[pos[0][0]]
}
}

return result

modelscope/pipelines/nlp/task_oriented_conversation_pipeline.py → modelscope/pipelines/nlp/dialog_modeling_pipeline.py View File

@@ -11,13 +11,12 @@ from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import DialogModelingPreprocessor
from modelscope.utils.constant import Tasks

__all__ = ['TaskOrientedConversationPipeline']
__all__ = ['DialogModelingPipeline']


@PIPELINES.register_module(
Tasks.task_oriented_conversation,
module_name=Pipelines.task_oriented_conversation)
class TaskOrientedConversationPipeline(Pipeline):
Tasks.task_oriented_conversation, module_name=Pipelines.dialog_modeling)
class DialogModelingPipeline(Pipeline):

def __init__(self,
model: Union[SpaceForDialogModeling, str],
@@ -51,6 +50,6 @@ class TaskOrientedConversationPipeline(Pipeline):
inputs['resp'])
assert len(sys_rsp) > 2
sys_rsp = sys_rsp[1:len(sys_rsp) - 1]
inputs[OutputKeys.RESPONSE] = sys_rsp
inputs[OutputKeys.OUTPUT] = sys_rsp

return inputs

+ 3
- 2
modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py View File

@@ -13,7 +13,8 @@ __all__ = ['DialogStateTrackingPipeline']


@PIPELINES.register_module(
Tasks.dialog_state_tracking, module_name=Pipelines.dialog_state_tracking)
Tasks.task_oriented_conversation,
module_name=Pipelines.dialog_state_tracking)
class DialogStateTrackingPipeline(Pipeline):

def __init__(self,
@@ -63,7 +64,7 @@ class DialogStateTrackingPipeline(Pipeline):
_outputs[5], unique_ids, input_ids_unmasked,
values, inform, prefix, ds)

return {OutputKeys.DIALOG_STATES: ds}
return {OutputKeys.OUTPUT: ds}


def predict_and_format(config, tokenizer, features, per_slot_class_logits,


+ 8
- 11
tests/pipelines/test_dialog_intent_prediction.py View File

@@ -20,7 +20,7 @@ class DialogIntentPredictionTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self):
cache_path = snapshot_download(self.model_id)
cache_path = snapshot_download(self.model_id, revision='update')
preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path)
model = SpaceForDialogIntent(
model_dir=cache_path,
@@ -31,7 +31,7 @@ class DialogIntentPredictionTest(unittest.TestCase):
DialogIntentPredictionPipeline(
model=model, preprocessor=preprocessor),
pipeline(
task=Tasks.dialog_intent_prediction,
task=Tasks.task_oriented_conversation,
model=model,
preprocessor=preprocessor)
]
@@ -41,7 +41,7 @@ class DialogIntentPredictionTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
model = Model.from_pretrained(self.model_id, revision='update')
preprocessor = DialogIntentPredictionPreprocessor(
model_dir=model.model_dir)

@@ -49,7 +49,7 @@ class DialogIntentPredictionTest(unittest.TestCase):
DialogIntentPredictionPipeline(
model=model, preprocessor=preprocessor),
pipeline(
task=Tasks.dialog_intent_prediction,
task=Tasks.task_oriented_conversation,
model=model,
preprocessor=preprocessor)
]
@@ -60,17 +60,14 @@ class DialogIntentPredictionTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
pipelines = [
pipeline(task=Tasks.dialog_intent_prediction, model=self.model_id)
pipeline(
task=Tasks.task_oriented_conversation,
model=self.model_id,
model_revision='update')
]
for my_pipeline, item in list(zip(pipelines, self.test_case)):
print(my_pipeline(item))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipelines = [pipeline(task=Tasks.dialog_intent_prediction)]
for my_pipeline, item in list(zip(pipelines, self.test_case)):
print(my_pipeline(item))


if __name__ == '__main__':
unittest.main()

tests/pipelines/test_task_oriented_conversation.py → tests/pipelines/test_dialog_modeling.py View File

@@ -5,14 +5,15 @@ from typing import List
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SpaceForDialogModeling
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import TaskOrientedConversationPipeline
from modelscope.pipelines.nlp import DialogModelingPipeline
from modelscope.preprocessors import DialogModelingPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class TaskOrientedConversationTest(unittest.TestCase):
class DialogModelingTest(unittest.TestCase):
model_id = 'damo/nlp_space_dialog-modeling'
test_case = {
'sng0073': {
@@ -92,23 +93,25 @@ class TaskOrientedConversationTest(unittest.TestCase):
}

def generate_and_print_dialog_response(
self, pipelines: List[TaskOrientedConversationPipeline]):
self, pipelines: List[DialogModelingPipeline]):

result = {}
pipeline_len = len(pipelines)
for step, item in enumerate(self.test_case['sng0073']['log']):
user = item['user']
print('user: {}'.format(user))

result = pipelines[step % 2]({
result = pipelines[step % pipeline_len]({
'user_input': user,
'history': result
})
print('response : {}'.format(result['response']))
print('response : {}'.format(result[OutputKeys.OUTPUT]))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self):

cache_path = snapshot_download(self.model_id)
cache_path = snapshot_download(
self.model_id, revision='task_oriented_conversation')

preprocessor = DialogModelingPreprocessor(model_dir=cache_path)
model = SpaceForDialogModeling(
@@ -116,27 +119,18 @@ class TaskOrientedConversationTest(unittest.TestCase):
text_field=preprocessor.text_field,
config=preprocessor.config)
pipelines = [
TaskOrientedConversationPipeline(
model=model, preprocessor=preprocessor),
pipeline(
task=Tasks.task_oriented_conversation,
model=model,
preprocessor=preprocessor)
DialogModelingPipeline(model=model, preprocessor=preprocessor)
]
self.generate_and_print_dialog_response(pipelines)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
model = Model.from_pretrained(
self.model_id, revision='task_oriented_conversation')
preprocessor = DialogModelingPreprocessor(model_dir=model.model_dir)

pipelines = [
TaskOrientedConversationPipeline(
model=model, preprocessor=preprocessor),
pipeline(
task=Tasks.task_oriented_conversation,
model=model,
preprocessor=preprocessor)
DialogModelingPipeline(model=model, preprocessor=preprocessor)
]

self.generate_and_print_dialog_response(pipelines)
@@ -145,17 +139,18 @@ class TaskOrientedConversationTest(unittest.TestCase):
def test_run_with_model_name(self):
pipelines = [
pipeline(
task=Tasks.task_oriented_conversation, model=self.model_id),
pipeline(
task=Tasks.task_oriented_conversation, model=self.model_id)
task=Tasks.task_oriented_conversation,
model=self.model_id,
model_revision='task_oriented_conversation')
]
self.generate_and_print_dialog_response(pipelines)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipelines = [
pipeline(task=Tasks.task_oriented_conversation),
pipeline(task=Tasks.task_oriented_conversation)
pipeline(
task=Tasks.task_oriented_conversation,
model_revision='task_oriented_conversation')
]
self.generate_and_print_dialog_response(pipelines)


+ 11
- 11
tests/pipelines/test_dialog_state_tracking.py View File

@@ -5,6 +5,7 @@ from typing import List
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SpaceForDialogStateTracking
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import DialogStateTrackingPipeline
from modelscope.preprocessors import DialogStateTrackingPreprocessor
@@ -94,11 +95,11 @@ class DialogStateTrackingTest(unittest.TestCase):
})
print(json.dumps(result))

history_states.extend([result['dialog_states'], {}])
history_states.extend([result[OutputKeys.OUTPUT], {}])

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self):
cache_path = snapshot_download(self.model_id)
cache_path = snapshot_download(self.model_id, revision='update')

model = SpaceForDialogStateTracking(cache_path)
preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path)
@@ -106,7 +107,7 @@ class DialogStateTrackingTest(unittest.TestCase):
DialogStateTrackingPipeline(
model=model, preprocessor=preprocessor),
pipeline(
task=Tasks.dialog_state_tracking,
task=Tasks.task_oriented_conversation,
model=model,
preprocessor=preprocessor)
]
@@ -114,14 +115,15 @@ class DialogStateTrackingTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
model = Model.from_pretrained(self.model_id, revision='update')

preprocessor = DialogStateTrackingPreprocessor(
model_dir=model.model_dir)
pipelines = [
DialogStateTrackingPipeline(
model=model, preprocessor=preprocessor),
pipeline(
task=Tasks.dialog_state_tracking,
task=Tasks.task_oriented_conversation,
model=model,
preprocessor=preprocessor)
]
@@ -131,15 +133,13 @@ class DialogStateTrackingTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
pipelines = [
pipeline(task=Tasks.dialog_state_tracking, model=self.model_id)
pipeline(
task=Tasks.task_oriented_conversation,
model=self.model_id,
model_revision='update')
]
self.tracking_and_print_dialog_states(pipelines)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipelines = [pipeline(task=Tasks.dialog_state_tracking)]
self.tracking_and_print_dialog_states(pipelines)


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save