Browse Source

fix bug

master
ly119399 3 years ago
parent
commit
d422510158
24 changed files with 45 additions and 102 deletions
  1. +4
    -4
      modelscope/models/nlp/space/dialog_intent_prediction_model.py
  2. +4
    -4
      modelscope/models/nlp/space/dialog_modeling_model.py
  3. +2
    -3
      modelscope/models/nlp/space/model/gen_unified_transformer.py
  4. +1
    -3
      modelscope/models/nlp/space/model/generator.py
  5. +2
    -3
      modelscope/models/nlp/space/model/intent_unified_transformer.py
  6. +2
    -3
      modelscope/models/nlp/space/model/model_base.py
  7. +1
    -3
      modelscope/models/nlp/space/model/unified_transformer.py
  8. +1
    -3
      modelscope/models/nlp/space/modules/embedder.py
  9. +1
    -3
      modelscope/models/nlp/space/modules/feedforward.py
  10. +1
    -3
      modelscope/models/nlp/space/modules/functions.py
  11. +1
    -3
      modelscope/models/nlp/space/modules/multihead_attention.py
  12. +1
    -3
      modelscope/models/nlp/space/modules/transformer_block.py
  13. +4
    -3
      modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py
  14. +4
    -3
      modelscope/pipelines/nlp/dialog_modeling_pipeline.py
  15. +2
    -4
      modelscope/preprocessors/space/fields/gen_field.py
  16. +2
    -3
      modelscope/preprocessors/space/fields/intent_field.py
  17. +0
    -8
      modelscope/utils/nlp/space/db_ops.py
  18. +0
    -13
      modelscope/utils/nlp/space/ontology.py
  19. +0
    -18
      modelscope/utils/nlp/space/utils.py
  20. +0
    -1
      requirements/nlp.txt
  21. +5
    -3
      tests/pipelines/test_dialog_intent_prediction.py
  22. +5
    -6
      tests/pipelines/test_dialog_modeling.py
  23. +1
    -1
      tests/pipelines/test_nli.py
  24. +1
    -1
      tests/pipelines/test_sentiment_classification.py

+ 4
- 4
modelscope/models/nlp/space/dialog_intent_prediction_model.py View File

@@ -1,3 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Any, Dict

@@ -11,20 +13,18 @@ from ...builder import MODELS
from .model.generator import Generator
from .model.model_base import SpaceModelBase

__all__ = ['SpaceForDialogIntentModel']
__all__ = ['SpaceForDialogIntent']


@MODELS.register_module(
Tasks.dialog_intent_prediction, module_name=Models.space)
class SpaceForDialogIntentModel(Model):
class SpaceForDialogIntent(Model):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the test generation model from the `model_dir` path.

Args:
model_dir (str): the model path.
model_cls (Optional[Any], optional): model loader, if None, use the
default loader to load model weights, by default None.
"""

super().__init__(model_dir, *args, **kwargs)


+ 4
- 4
modelscope/models/nlp/space/dialog_modeling_model.py View File

@@ -1,3 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Any, Dict, Optional

@@ -11,19 +13,17 @@ from ...builder import MODELS
from .model.generator import Generator
from .model.model_base import SpaceModelBase

__all__ = ['SpaceForDialogModelingModel']
__all__ = ['SpaceForDialogModeling']


@MODELS.register_module(Tasks.dialog_modeling, module_name=Models.space)
class SpaceForDialogModelingModel(Model):
class SpaceForDialogModeling(Model):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the test generation model from the `model_dir` path.

Args:
model_dir (str): the model path.
model_cls (Optional[Any], optional): model loader, if None, use the
default loader to load model weights, by default None.
"""

super().__init__(model_dir, *args, **kwargs)


+ 2
- 3
modelscope/models/nlp/space/model/gen_unified_transformer.py View File

@@ -1,6 +1,5 @@
"""
IntentUnifiedTransformer
"""
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch

from .unified_transformer import UnifiedTransformer


+ 1
- 3
modelscope/models/nlp/space/model/generator.py View File

@@ -1,6 +1,4 @@
"""
Generator class.
"""
# Copyright (c) Alibaba, Inc. and its affiliates.

import math



+ 2
- 3
modelscope/models/nlp/space/model/intent_unified_transformer.py View File

@@ -1,6 +1,5 @@
"""
IntentUnifiedTransformer
"""
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
import torch.nn as nn
import torch.nn.functional as F


+ 2
- 3
modelscope/models/nlp/space/model/model_base.py View File

@@ -1,6 +1,5 @@
"""
Model base
"""
# Copyright (c) Alibaba, Inc. and its affiliates.

import os

import torch.nn as nn


+ 1
- 3
modelscope/models/nlp/space/model/unified_transformer.py View File

@@ -1,6 +1,4 @@
"""
UnifiedTransformer
"""
# Copyright (c) Alibaba, Inc. and its affiliates.

import numpy as np
import torch


+ 1
- 3
modelscope/models/nlp/space/modules/embedder.py View File

@@ -1,6 +1,4 @@
"""
Embedder class.
"""
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
import torch.nn as nn


+ 1
- 3
modelscope/models/nlp/space/modules/feedforward.py View File

@@ -1,6 +1,4 @@
"""
FeedForward class.
"""
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
import torch.nn as nn


+ 1
- 3
modelscope/models/nlp/space/modules/functions.py View File

@@ -1,6 +1,4 @@
"""
Helpful functions.
"""
# Copyright (c) Alibaba, Inc. and its affiliates.

import numpy as np
import torch


+ 1
- 3
modelscope/models/nlp/space/modules/multihead_attention.py View File

@@ -1,6 +1,4 @@
"""
MultiheadAttention class.
"""
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
import torch.nn as nn


+ 1
- 3
modelscope/models/nlp/space/modules/transformer_block.py View File

@@ -1,6 +1,4 @@
"""
TransformerBlock class.
"""
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
import torch.nn as nn


+ 4
- 3
modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py View File

@@ -1,7 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict

from ...metainfo import Pipelines
from ...models.nlp import SpaceForDialogIntentModel
from ...models.nlp import SpaceForDialogIntent
from ...preprocessors import DialogIntentPredictionPreprocessor
from ...utils.constant import Tasks
from ..base import Pipeline
@@ -15,7 +17,7 @@ __all__ = ['DialogIntentPredictionPipeline']
module_name=Pipelines.dialog_intent_prediction)
class DialogIntentPredictionPipeline(Pipeline):

def __init__(self, model: SpaceForDialogIntentModel,
def __init__(self, model: SpaceForDialogIntent,
preprocessor: DialogIntentPredictionPreprocessor, **kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction

@@ -26,7 +28,6 @@ class DialogIntentPredictionPipeline(Pipeline):

super().__init__(model=model, preprocessor=preprocessor, **kwargs)
self.model = model
# self.tokenizer = preprocessor.tokenizer

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""process the prediction results


+ 4
- 3
modelscope/pipelines/nlp/dialog_modeling_pipeline.py View File

@@ -1,7 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict, Optional

from ...metainfo import Pipelines
from ...models.nlp import SpaceForDialogModelingModel
from ...models.nlp import SpaceForDialogModeling
from ...preprocessors import DialogModelingPreprocessor
from ...utils.constant import Tasks
from ..base import Pipeline, Tensor
@@ -14,7 +16,7 @@ __all__ = ['DialogModelingPipeline']
Tasks.dialog_modeling, module_name=Pipelines.dialog_modeling)
class DialogModelingPipeline(Pipeline):

def __init__(self, model: SpaceForDialogModelingModel,
def __init__(self, model: SpaceForDialogModeling,
preprocessor: DialogModelingPreprocessor, **kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction

@@ -40,7 +42,6 @@ class DialogModelingPipeline(Pipeline):
inputs['resp'])
assert len(sys_rsp) > 2
sys_rsp = sys_rsp[1:len(sys_rsp) - 1]
# sys_rsp = self.preprocessor.text_field.tokenizer.

inputs['sys'] = sys_rsp



+ 2
- 4
modelscope/preprocessors/space/fields/gen_field.py View File

@@ -1,6 +1,5 @@
"""
Field class
"""
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import random
from collections import OrderedDict
@@ -8,7 +7,6 @@ from itertools import chain

import numpy as np

from ....utils.constant import ModelFile
from ....utils.nlp.space import ontology, utils
from ....utils.nlp.space.db_ops import MultiWozDB
from ....utils.nlp.space.utils import list2np


+ 2
- 3
modelscope/preprocessors/space/fields/intent_field.py View File

@@ -1,6 +1,5 @@
"""
Intent Field class
"""
# Copyright (c) Alibaba, Inc. and its affiliates.

import glob
import multiprocessing
import os


+ 0
- 8
modelscope/utils/nlp/space/db_ops.py View File

@@ -308,14 +308,6 @@ if __name__ == '__main__':
'attraction': 5,
'train': 1,
}
# for ent in res:
# if reidx.get(domain):
# report.append(ent[reidx[domain]])
# for ent in res:
# if 'name' in ent:
# report.append(ent['name'])
# if 'trainid' in ent:
# report.append(ent['trainid'])
print(constraints)
print(res)
print('count:', len(res), '\nnames:', report)

+ 0
- 13
modelscope/utils/nlp/space/ontology.py View File

@@ -123,19 +123,6 @@ dialog_act_all_slots = all_slots + ['choice', 'open']
# no need of this, just covert slot to [slot] e.g. pricerange -> [pricerange]
slot_name_to_slot_token = {}

# special slot tokens in responses
# not use at the momoent
slot_name_to_value_token = {
# 'entrance fee': '[value_price]',
# 'pricerange': '[value_price]',
# 'arriveby': '[value_time]',
# 'leaveat': '[value_time]',
# 'departure': '[value_place]',
# 'destination': '[value_place]',
# 'stay': 'count',
# 'people': 'count'
}

# eos tokens definition
eos_tokens = {
'user': '<eos_u>',


+ 0
- 18
modelscope/utils/nlp/space/utils.py View File

@@ -53,16 +53,9 @@ def clean_replace(s, r, t, forward=True, backward=False):
return s, -1
return s[:idx] + t + s[idx_r:], idx_r

# source, replace, target = s, r, t
# count = 0
sidx = 0
while sidx != -1:
s, sidx = clean_replace_single(s, r, t, forward, backward, sidx)
# count += 1
# print(s, sidx)
# if count == 20:
# print(source, '\n', replace, '\n', target)
# quit()
return s


@@ -193,14 +186,3 @@ class MultiWOZVocab(object):
return self._idx2word[idx]
else:
return self._idx2word[idx] + '(o)'

# def sentence_decode(self, index_list, eos=None, indicate_oov=False):
# l = [self.decode(_, indicate_oov) for _ in index_list]
# if not eos or eos not in l:
# return ' '.join(l)
# else:
# idx = l.index(eos)
# return ' '.join(l[:idx])
#
# def nl_decode(self, l, eos=None):
# return [self.sentence_decode(_, eos) + '\n' for _ in l]

+ 0
- 1
requirements/nlp.txt View File

@@ -1,4 +1,3 @@
https://alinlp.alibaba-inc.com/pypi/sofa-1.0.3-py3-none-any.whl
https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz
spacy>=2.3.5
# python -m spacy download en_core_web_sm

+ 5
- 3
tests/pipelines/test_dialog_intent_prediction.py View File

@@ -3,10 +3,11 @@ import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SpaceForDialogIntentModel
from modelscope.models.nlp import SpaceForDialogIntent
from modelscope.pipelines import DialogIntentPredictionPipeline, pipeline
from modelscope.preprocessors import DialogIntentPredictionPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class DialogIntentPredictionTest(unittest.TestCase):
@@ -16,11 +17,11 @@ class DialogIntentPredictionTest(unittest.TestCase):
'I still have not received my new card, I ordered over a week ago.'
]

@unittest.skip('test with snapshot_download')
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run(self):
cache_path = snapshot_download(self.model_id)
preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path)
model = SpaceForDialogIntentModel(
model = SpaceForDialogIntent(
model_dir=cache_path,
text_field=preprocessor.text_field,
config=preprocessor.config)
@@ -37,6 +38,7 @@ class DialogIntentPredictionTest(unittest.TestCase):
for my_pipeline, item in list(zip(pipelines, self.test_case)):
print(my_pipeline(item))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
preprocessor = DialogIntentPredictionPreprocessor(


+ 5
- 6
tests/pipelines/test_dialog_modeling.py View File

@@ -1,15 +1,13 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import tempfile
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SpaceForDialogModelingModel
from modelscope.models.nlp import SpaceForDialogModeling
from modelscope.pipelines import DialogModelingPipeline, pipeline
from modelscope.preprocessors import DialogModelingPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class DialogModelingTest(unittest.TestCase):
@@ -91,13 +89,13 @@ class DialogModelingTest(unittest.TestCase):
}
}

@unittest.skip('test with snapshot_download')
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run(self):

cache_path = snapshot_download(self.model_id)

preprocessor = DialogModelingPreprocessor(model_dir=cache_path)
model = SpaceForDialogModelingModel(
model = SpaceForDialogModeling(
model_dir=cache_path,
text_field=preprocessor.text_field,
config=preprocessor.config)
@@ -120,6 +118,7 @@ class DialogModelingTest(unittest.TestCase):
})
print('sys : {}'.format(result['sys']))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
preprocessor = DialogModelingPreprocessor(model_dir=model.model_dir)


+ 1
- 1
tests/pipelines/test_nli.py View File

@@ -37,7 +37,7 @@ class NLITest(unittest.TestCase):
task=Tasks.nli, model=model, preprocessor=tokenizer)
print(pipeline_ins(input=(self.sentence1, self.sentence2)))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_model_name(self):
pipeline_ins = pipeline(task=Tasks.nli, model=self.model_id)
print(pipeline_ins(input=(self.sentence1, self.sentence2)))


+ 1
- 1
tests/pipelines/test_sentiment_classification.py View File

@@ -42,7 +42,7 @@ class SentimentClassificationTest(unittest.TestCase):
preprocessor=tokenizer)
print(pipeline_ins(input=self.sentence1))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_model_name(self):
pipeline_ins = pipeline(
task=Tasks.sentiment_classification, model=self.model_id)


Loading…
Cancel
Save