Browse Source

add outputs

master
ly119399 3 years ago
parent
commit
88e6766944
3 changed files with 38 additions and 1 deletions
  1. +6
    -1
      modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py
  2. +25
    -0
      modelscope/pipelines/outputs.py
  3. +7
    -0
      modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py

+ 6
- 1
modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py View File

@@ -28,6 +28,7 @@ class DialogIntentPredictionPipeline(Pipeline):


super().__init__(model=model, preprocessor=preprocessor, **kwargs) super().__init__(model=model, preprocessor=preprocessor, **kwargs)
self.model = model self.model = model
self.categories = preprocessor.categories


def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""process the prediction results """process the prediction results
@@ -42,6 +43,10 @@ class DialogIntentPredictionPipeline(Pipeline):
pred = inputs['pred'] pred = inputs['pred']
pos = np.where(pred == np.max(pred)) pos = np.where(pred == np.max(pred))


result = {'pred': pred, 'label': pos[0]}
result = {
'pred': pred,
'label_pos': pos[0],
'label': self.categories[pos[0][0]]
}


return result return result

+ 25
- 0
modelscope/pipelines/outputs.py View File

@@ -122,6 +122,31 @@ TASK_OUTPUTS = {
# } # }
Tasks.nli: ['scores', 'labels'], Tasks.nli: ['scores', 'labels'],


# {'pred': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05,
# 1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04,
# 6.45841064e-05, 1.15997791e-04, 5.11605394e-05, 9.87020373e-01,
# 2.66957268e-05, 4.72324500e-05, 9.74208378e-05, 4.18022355e-05,
# 2.97343540e-05, 5.81317654e-05, 5.44203431e-05, 6.28319322e-05,
# 7.34537680e-05, 6.61411541e-05, 3.62534920e-05, 8.58885178e-05,
# 8.24327726e-05, 4.66077945e-05, 5.32869453e-05, 4.16190960e-05,
# 5.97518992e-05, 3.92273068e-05, 3.44069012e-05, 9.92335918e-05,
# 9.25978165e-05, 6.26462061e-05, 3.32317031e-05, 1.32061413e-03,
# 2.01607945e-05, 3.36636294e-05, 3.99156743e-05, 5.84108493e-05,
# 2.53432900e-05, 4.95731190e-04, 2.64443643e-05, 4.46992999e-05,
# 2.42672231e-05, 4.75615161e-05, 2.66230145e-05, 4.00083954e-05,
# 2.90536875e-04, 4.23891543e-05, 8.63691166e-05, 4.98188965e-05,
# 3.47019341e-05, 4.52718523e-05, 4.20905781e-05, 5.50173208e-05,
# 4.92360487e-05, 3.56021264e-05, 2.13957210e-05, 6.17428886e-05,
# 1.43893281e-04, 7.32152112e-05, 2.91354867e-04, 2.46623786e-05,
# 3.61441926e-05, 3.38475402e-05, 3.44323053e-05, 5.70138109e-05,
# 4.31488479e-05, 4.94503947e-05, 4.30105974e-05, 1.00963116e-04,
# 2.82062047e-05, 1.15582036e-04, 4.48261271e-05, 3.99339879e-05,
# 7.27692823e-05], dtype=float32), 'label_pos': array([11]), 'label': 'lost_or_stolen_card'}
Tasks.dialog_intent_prediction: ['pred', 'label_pos', 'label'],

# sys : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!']
Tasks.dialog_modeling: ['sys'],

# ============ audio tasks =================== # ============ audio tasks ===================


# audio processed for single file in PCM format # audio processed for single file in PCM format


+ 7
- 0
modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py View File

@@ -3,6 +3,8 @@
import os import os
from typing import Any, Dict from typing import Any, Dict


import json

from ...metainfo import Preprocessors from ...metainfo import Preprocessors
from ...utils.config import Config from ...utils.config import Config
from ...utils.constant import Fields, ModelFile from ...utils.constant import Fields, ModelFile
@@ -32,6 +34,11 @@ class DialogIntentPredictionPreprocessor(Preprocessor):
self.text_field = IntentBPETextField( self.text_field = IntentBPETextField(
self.model_dir, config=self.config) self.model_dir, config=self.config)


self.categories = None
with open(os.path.join(self.model_dir, 'categories.json'), 'r') as f:
self.categories = json.load(f)
assert len(self.categories) == 77

@type_assert(object, str) @type_assert(object, str)
def __call__(self, data: str) -> Dict[str, Any]: def __call__(self, data: str) -> Dict[str, Any]:
"""process the raw input data """process the raw input data


Loading…
Cancel
Save