Browse Source

formating output

master
智丞 3 years ago
parent
commit
7b3d792943
2 changed files with 7 additions and 34 deletions
  1. +2
    -1
      modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py
  2. +5
    -33
      modelscope/pipelines/outputs.py

+ 2
- 1
modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py View File

@@ -6,6 +6,7 @@ from ...preprocessors import DialogStateTrackingPreprocessor
from ...utils.constant import Tasks from ...utils.constant import Tasks
from ..base import Pipeline from ..base import Pipeline
from ..builder import PIPELINES from ..builder import PIPELINES
from ..outputs import OutputKeys


__all__ = ['DialogStateTrackingPipeline'] __all__ = ['DialogStateTrackingPipeline']


@@ -53,7 +54,7 @@ class DialogStateTrackingPipeline(Pipeline):
_outputs[5], unique_ids, input_ids_unmasked, _outputs[5], unique_ids, input_ids_unmasked,
values, inform, prefix, ds) values, inform, prefix, ds)


return {'dialog_states': ds}
return {OutputKeys.DIALOG_STATES: ds}




def predict_and_format(config, tokenizer, features, per_slot_class_logits, def predict_and_format(config, tokenizer, features, per_slot_class_logits,


+ 5
- 33
modelscope/pipelines/outputs.py View File

@@ -20,6 +20,7 @@ class OutputKeys(object):
TEXT_EMBEDDING = 'text_embedding' TEXT_EMBEDDING = 'text_embedding'
RESPONSE = 'response' RESPONSE = 'response'
PREDICTION = 'prediction' PREDICTION = 'prediction'
DIALOG_STATES = 'dialog_states'




TASK_OUTPUTS = { TASK_OUTPUTS = {
@@ -151,6 +152,7 @@ TASK_OUTPUTS = {
# } # }
Tasks.nli: [OutputKeys.SCORES, OutputKeys.LABELS], Tasks.nli: [OutputKeys.SCORES, OutputKeys.LABELS],


# dialog intent prediction result for single sample
# {'pred': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05, # {'pred': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05,
# 1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04, # 1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04,
# 6.45841064e-05, 1.15997791e-04, 5.11605394e-05, 9.87020373e-01, # 6.45841064e-05, 1.15997791e-04, 5.11605394e-05, 9.87020373e-01,
@@ -174,16 +176,11 @@ TASK_OUTPUTS = {
Tasks.dialog_intent_prediction: Tasks.dialog_intent_prediction:
[OutputKeys.PREDICTION, OutputKeys.LABEL_POS, OutputKeys.LABEL], [OutputKeys.PREDICTION, OutputKeys.LABEL_POS, OutputKeys.LABEL],


# dialog modeling prediction result for single sample
# sys : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!'] # sys : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!']
Tasks.dialog_modeling: [OutputKeys.RESPONSE], Tasks.dialog_modeling: [OutputKeys.RESPONSE],


# nli result for single sample
# {
# "labels": ["happy", "sad", "calm", "angry"],
# "scores": [0.9, 0.1, 0.05, 0.05]
# }
Tasks.nli: ['scores', 'labels'],

# dialog state tracking result for single sample
# { # {
# "dialog_states": { # "dialog_states": {
# "taxi-leaveAt": "none", # "taxi-leaveAt": "none",
@@ -218,32 +215,7 @@ TASK_OUTPUTS = {
# "train-departure": "none" # "train-departure": "none"
# } # }
# } # }
Tasks.dialog_state_tracking: ['dialog_states'],

# {'pred': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05,
# 1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04,
# 6.45841064e-05, 1.15997791e-04, 5.11605394e-05, 9.87020373e-01,
# 2.66957268e-05, 4.72324500e-05, 9.74208378e-05, 4.18022355e-05,
# 2.97343540e-05, 5.81317654e-05, 5.44203431e-05, 6.28319322e-05,
# 7.34537680e-05, 6.61411541e-05, 3.62534920e-05, 8.58885178e-05,
# 8.24327726e-05, 4.66077945e-05, 5.32869453e-05, 4.16190960e-05,
# 5.97518992e-05, 3.92273068e-05, 3.44069012e-05, 9.92335918e-05,
# 9.25978165e-05, 6.26462061e-05, 3.32317031e-05, 1.32061413e-03,
# 2.01607945e-05, 3.36636294e-05, 3.99156743e-05, 5.84108493e-05,
# 2.53432900e-05, 4.95731190e-04, 2.64443643e-05, 4.46992999e-05,
# 2.42672231e-05, 4.75615161e-05, 2.66230145e-05, 4.00083954e-05,
# 2.90536875e-04, 4.23891543e-05, 8.63691166e-05, 4.98188965e-05,
# 3.47019341e-05, 4.52718523e-05, 4.20905781e-05, 5.50173208e-05,
# 4.92360487e-05, 3.56021264e-05, 2.13957210e-05, 6.17428886e-05,
# 1.43893281e-04, 7.32152112e-05, 2.91354867e-04, 2.46623786e-05,
# 3.61441926e-05, 3.38475402e-05, 3.44323053e-05, 5.70138109e-05,
# 4.31488479e-05, 4.94503947e-05, 4.30105974e-05, 1.00963116e-04,
# 2.82062047e-05, 1.15582036e-04, 4.48261271e-05, 3.99339879e-05,
# 7.27692823e-05], dtype=float32), 'label_pos': array([11]), 'label': 'lost_or_stolen_card'}
Tasks.dialog_intent_prediction: ['prediction', 'label_pos', 'label'],

# sys : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!']
Tasks.dialog_modeling: ['response'],
Tasks.dialog_state_tracking: [OutputKeys.DIALOG_STATES],


# ============ audio tasks =================== # ============ audio tasks ===================




Loading…
Cancel
Save