
unfinished change

master
雨泓 3 years ago
parent commit c5693f2a84
29 changed files with 215 additions and 341 deletions
  1. modelscope/metainfo.py (+5, -1)
  2. modelscope/models/__init__.py (+10, -2)
  3. modelscope/models/nlp/__init__.py (+3, -3)
  4. modelscope/models/nlp/masked_language_model.py (+19, -18)
  5. modelscope/models/nlp/nli_model.py (+0, -84)
  6. modelscope/models/nlp/palm_for_text_generation.py (+3, -4)
  7. modelscope/models/nlp/sbert_for_nli.py (+21, -0)
  8. modelscope/models/nlp/sbert_for_sentence_similarity.py (+3, -68)
  9. modelscope/models/nlp/sbert_for_sentiment_classification.py (+22, -0)
  10. modelscope/models/nlp/sbert_for_sequence_classification.py (+55, -0)
  11. modelscope/models/nlp/sbert_for_token_classification.py (+4, -3)
  12. modelscope/models/nlp/sbert_for_zero_shot_classification.py (+5, -7)
  13. modelscope/models/nlp/sentiment_classification_model.py (+0, -85)
  14. modelscope/models/nlp/space/dialog_intent_prediction_model.py (+4, -4)
  15. modelscope/models/nlp/space/dialog_modeling_model.py (+4, -4)
  16. modelscope/models/nlp/space/model/gen_unified_transformer.py (+1, -1)
  17. modelscope/models/nlp/space/model/intent_unified_transformer.py (+1, -1)
  18. modelscope/models/nlp/space/model/unified_transformer.py (+3, -3)
  19. modelscope/models/nlp/space/modules/transformer_block.py (+2, -2)
  20. modelscope/pipelines/nlp/sentence_similarity_pipeline.py (+5, -5)
  21. modelscope/pipelines/nlp/space/dialog_intent_prediction_pipeline.py (+5, -5)
  22. modelscope/pipelines/nlp/zero_shot_classification_pipeline.py (+6, -6)
  23. modelscope/preprocessors/nlp.py (+15, -14)
  24. modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py (+4, -4)
  25. modelscope/preprocessors/space/dialog_modeling_preprocessor.py (+5, -6)
  26. modelscope/preprocessors/space/fields/gen_field.py (+4, -4)
  27. modelscope/preprocessors/space/fields/intent_field.py (+4, -4)
  28. modelscope/trainers/nlp/space/trainers/gen_trainer.py (+1, -1)
  29. modelscope/trainers/nlp/space/trainers/intent_trainer.py (+1, -2)

modelscope/metainfo.py (+5, -1)

@@ -13,8 +13,9 @@ class Models(object):

# nlp models
bert = 'bert'
palm2_0 = 'palm2.0'
palm = 'palm_v2'
structbert = 'structbert'
veco = 'veco'

# audio models
sambert_hifi_16k = 'sambert-hifi-16k'
@@ -85,6 +86,9 @@ class Preprocessors(object):
bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
palm_text_gen_tokenizer = 'palm-text-gen-tokenizer'
sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'
sbert_nli_tokenizer = 'sbert-nli-tokenizer'
sbert_sen_cls_tokenizer = 'sbert-sen-cls-tokenizer'
sbert_zero_shot_cls_tokenizer = 'sbert-zero-shot-cls-tokenizer'

# audio preprocessor
linear_aec_fbank = 'linear-aec-fbank'
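
For context, a short sketch (not part of this diff) of how these constants are consumed: the same string serves as the registry key at registration time and in configuration, replacing the raw r'...' literals removed elsewhere in this commit. The assert values simply mirror the definitions above.

# Sketch based on the constants added above; values mirror the diff.
from modelscope.metainfo import Models, Preprocessors

assert Models.structbert == 'structbert'
assert Models.palm == 'palm_v2'
assert Preprocessors.sbert_nli_tokenizer == 'sbert-nli-tokenizer'

# registration then keys off the shared constant instead of a literal:
# @MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert)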


modelscope/models/__init__.py (+10, -2)

@@ -5,5 +5,13 @@ from .audio.tts.vocoder import Hifigan16k
from .base import Model
from .builder import MODELS, build_model
from .multi_model import OfaForImageCaptioning
from .nlp import (BertForSequenceClassification, SbertForNLI,
SbertForSentenceSimilarity)
from .nlp import (
BertForSequenceClassification,
SbertForNLI,
SbertForSentenceSimilarity,
SbertForSentimentClassification,
SbertForZeroShotClassification,
StructBertForMaskedLM,
VecoForMaskedLM,
StructBertForTokenClassification,
)

modelscope/models/nlp/__init__.py (+3, -3)

@@ -1,10 +1,10 @@
from .bert_for_sequence_classification import * # noqa F403
from .masked_language_model import * # noqa F403
from .nli_model import * # noqa F403
from .sbert_for_nli import * # noqa F403
from .palm_for_text_generation import * # noqa F403
from .sbert_for_sentence_similarity import * # noqa F403
from .sbert_for_token_classification import * # noqa F403
from .sentiment_classification_model import * # noqa F403
from .sbert_for_sentiment_classification import * # noqa F403
from .space.dialog_intent_prediction_model import * # noqa F403
from .space.dialog_modeling_model import * # noqa F403
from .zero_shot_classification_model import * # noqa F403
from .sbert_for_zero_shot_classification import * # noqa F403

modelscope/models/nlp/masked_language_model.py (+19, -18)

@@ -5,26 +5,25 @@ import numpy as np
from ...utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS
from ...metainfo import Models

__all__ = ['StructBertForMaskedLM', 'VecoForMaskedLM']
__all__ = ['StructBertForMaskedLM', 'VecoForMaskedLM', 'MaskedLMModelBase']


class AliceMindBaseForMaskedLM(Model):
class MaskedLMModelBase(Model):

def __init__(self, model_dir: str, *args, **kwargs):
from sofa.utils.backend import AutoConfig, AutoModelForMaskedLM
self.model_dir = model_dir
super().__init__(model_dir, *args, **kwargs)
self.model = self.build_model()

self.config = AutoConfig.from_pretrained(model_dir)
self.model = AutoModelForMaskedLM.from_pretrained(
model_dir, config=self.config)
def build_model(self):
raise NotImplementedError()

def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, np.ndarray]:
"""return the result by the model

Args:
input (Dict[str, Any]): the preprocessed data
inputs (Dict[str, Any]): the preprocessed data

Returns:
Dict[str, np.ndarray]: results
@@ -36,15 +35,17 @@ class AliceMindBaseForMaskedLM(Model):
return {'logits': rst['logits'], 'input_ids': inputs['input_ids']}


@MODELS.register_module(Tasks.fill_mask, module_name=r'sbert')
class StructBertForMaskedLM(AliceMindBaseForMaskedLM):
# The StructBert for MaskedLM uses the same underlying model structure
# as the base model class.
pass
@MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert)
class StructBertForMaskedLM(MaskedLMModelBase):

def build_model(self):
from sofa import SbertForMaskedLM
return SbertForMaskedLM.from_pretrained(self.model_dir)

@MODELS.register_module(Tasks.fill_mask, module_name=r'veco')
class VecoForMaskedLM(AliceMindBaseForMaskedLM):
# The Veco for MaskedLM uses the same underlying model structure
# as the base model class.
pass

@MODELS.register_module(Tasks.fill_mask, module_name=Models.veco)
class VecoForMaskedLM(MaskedLMModelBase):

def build_model(self):
from sofa import VecoForMaskedLM
return VecoForMaskedLM.from_pretrained(self.model_dir)
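
The refactor above turns the base class into a template method: subclasses register under a Models constant and override only build_model. A minimal sketch of adding another backend, with hypothetical names ('my-backend', my_backend, MyBackendMaskedLM) that are not part of this commit:

# Hypothetical subclass following the MaskedLMModelBase pattern above;
# 'my-backend' and the my_backend package are illustrative only.
@MODELS.register_module(Tasks.fill_mask, module_name='my-backend')
class MyBackendForMaskedLM(MaskedLMModelBase):

    def build_model(self):
        from my_backend import MyBackendMaskedLM  # assumed third-party import
        return MyBackendMaskedLM.from_pretrained(self.model_dir)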

modelscope/models/nlp/nli_model.py (+0, -84)

@@ -1,84 +0,0 @@
import os
from typing import Any, Dict

import numpy as np
import torch
from sofa import SbertConfig, SbertModel
from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel
from torch import nn
from transformers.activations import ACT2FN, get_activation
from transformers.models.bert.modeling_bert import SequenceClassifierOutput

from modelscope.utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS

__all__ = ['SbertForNLI']


class SbertTextClassifier(SbertPreTrainedModel):

def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config
self.encoder = SbertModel(config, add_pooling_layer=True)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

def forward(self, input_ids=None, token_type_ids=None):
outputs = self.encoder(
input_ids,
token_type_ids=token_type_ids,
return_dict=None,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
return logits


@MODELS.register_module(
Tasks.nli, module_name=r'nlp_structbert_nli_chinese-base')
class SbertForNLI(Model):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the text generation model from the `model_dir` path.

Args:
model_dir (str): the model path.
model_cls (Optional[Any], optional): model loader, if None, use the
default loader to load model weights, by default None.
"""
super().__init__(model_dir, *args, **kwargs)
self.model_dir = model_dir

self.model = SbertTextClassifier.from_pretrained(
model_dir, num_labels=3)
self.model.eval()

def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
"""return the result by the model

Args:
input (Dict[str, Any]): the preprocessed data

Returns:
Dict[str, np.ndarray]: results
Example:
{
'predictions': array([1]), # lable 0-negative 1-positive
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
}
"""
input_ids = torch.tensor(input['input_ids'], dtype=torch.long)
token_type_ids = torch.tensor(
input['token_type_ids'], dtype=torch.long)
with torch.no_grad():
logits = self.model(input_ids, token_type_ids)
probs = logits.softmax(-1).numpy()
pred = logits.argmax(-1).numpy()
logits = logits.numpy()
res = {'predictions': pred, 'probabilities': probs, 'logits': logits}
return res

modelscope/models/nlp/palm_for_text_generation.py (+3, -4)

@@ -1,14 +1,14 @@
from typing import Dict

from modelscope.metainfo import Models
from modelscope.utils.constant import Tasks
from ...metainfo import Models
from ...utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS

__all__ = ['PalmForTextGeneration']


@MODELS.register_module(Tasks.text_generation, module_name=Models.palm2_0)
@MODELS.register_module(Tasks.text_generation, module_name=Models.palm)
class PalmForTextGeneration(Model):

def __init__(self, model_dir: str, *args, **kwargs):
@@ -20,7 +20,6 @@ class PalmForTextGeneration(Model):
default loader to load model weights, by default None.
"""
super().__init__(model_dir, *args, **kwargs)
self.model_dir = model_dir

from sofa.models.palm_v2 import PalmForConditionalGeneration, Translator
model = PalmForConditionalGeneration.from_pretrained(model_dir)


modelscope/models/nlp/sbert_for_nli.py (+21, -0)

@@ -0,0 +1,21 @@
from modelscope.utils.constant import Tasks
from .sbert_for_sequence_classification import SbertForSequenceClassificationBase
from ..builder import MODELS
from ...metainfo import Models

__all__ = ['SbertForNLI']


@MODELS.register_module(Tasks.nli, module_name=Models.structbert)
class SbertForNLI(SbertForSequenceClassificationBase):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the text generation model from the `model_dir` path.

Args:
model_dir (str): the model path.
model_cls (Optional[Any], optional): model loader, if None, use the
default loader to load model weights, by default None.
"""
super().__init__(model_dir, *args, **kwargs)
assert self.model.config.num_labels == 3

modelscope/models/nlp/sbert_for_sentence_similarity.py (+3, -68)

@@ -1,46 +1,14 @@
import os
from typing import Any, Dict

import json
import numpy as np
import torch
from sofa import SbertModel
from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel
from torch import nn

from modelscope.metainfo import Models
from modelscope.utils.constant import Tasks
from ..base import Model, Tensor
from .sbert_for_sequence_classification import SbertForSequenceClassificationBase
from ..builder import MODELS

__all__ = ['SbertForSentenceSimilarity']


class SbertTextClassifier(SbertPreTrainedModel):

def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config
self.encoder = SbertModel(config, add_pooling_layer=True)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

def forward(self, input_ids=None, token_type_ids=None):
outputs = self.encoder(
input_ids,
token_type_ids=token_type_ids,
return_dict=None,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
return logits


@MODELS.register_module(
Tasks.sentence_similarity, module_name=Models.structbert)
class SbertForSentenceSimilarity(Model):
class SbertForSentenceSimilarity(SbertForSequenceClassificationBase):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the sentence similarity model from the `model_dir` path.
@@ -52,37 +20,4 @@ class SbertForSentenceSimilarity(Model):
"""
super().__init__(model_dir, *args, **kwargs)
self.model_dir = model_dir

self.model = SbertTextClassifier.from_pretrained(
model_dir, num_labels=2)
self.model.eval()
self.label_path = os.path.join(self.model_dir, 'label_mapping.json')
with open(self.label_path) as f:
self.label_mapping = json.load(f)
self.id2label = {idx: name for name, idx in self.label_mapping.items()}

def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
"""return the result by the model

Args:
input (Dict[str, Any]): the preprocessed data

Returns:
Dict[str, np.ndarray]: results
Example:
{
'predictions': array([1]), # lable 0-negative 1-positive
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
}
"""
input_ids = torch.tensor(input['input_ids'], dtype=torch.long)
token_type_ids = torch.tensor(
input['token_type_ids'], dtype=torch.long)
with torch.no_grad():
logits = self.model(input_ids, token_type_ids)
probs = logits.softmax(-1).numpy()
pred = logits.argmax(-1).numpy()
logits = logits.numpy()
res = {'predictions': pred, 'probabilities': probs, 'logits': logits}
return res
assert self.model.config.num_labels == 2

modelscope/models/nlp/sbert_for_sentiment_classification.py (+22, -0)

@@ -0,0 +1,22 @@
from modelscope.utils.constant import Tasks
from .sbert_for_sequence_classification import SbertForSequenceClassificationBase
from ..builder import MODELS

__all__ = ['SbertForSentimentClassification']


@MODELS.register_module(
Tasks.sentiment_classification,
module_name=r'sbert-sentiment-classification')
class SbertForSentimentClassification(SbertForSequenceClassificationBase):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the text generation model from the `model_dir` path.

Args:
model_dir (str): the model path.
model_cls (Optional[Any], optional): model loader, if None, use the
default loader to load model weights, by default None.
"""
super().__init__(model_dir, *args, **kwargs)
assert self.model.config.num_labels == 2

modelscope/models/nlp/sbert_for_sequence_classification.py (+55, -0)

@@ -0,0 +1,55 @@
from torch import nn
from typing import Any, Dict
from ..base import Model
import numpy as np
import json
import os
from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel, SbertModel


class SbertTextClassifier(SbertPreTrainedModel):

def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config
self.encoder = SbertModel(config, add_pooling_layer=True)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

def forward(self, input_ids=None, token_type_ids=None):
outputs = self.encoder(
input_ids,
token_type_ids=token_type_ids,
return_dict=None,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
return {'logits': logits}


class SbertForSequenceClassificationBase(Model):

def __init__(self, model_dir: str, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)
self.model = SbertTextClassifier.from_pretrained(model_dir)
self.id2label = {}
self.label_path = os.path.join(self.model_dir, 'label_mapping.json')
if os.path.exists(self.label_path):
with open(self.label_path) as f:
self.label_mapping = json.load(f)
self.id2label = {idx: name for name, idx in self.label_mapping.items()}

def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
# unpack the preprocessed dict into input_ids / token_type_ids
return self.model(**input)

def postprocess(self, input, **kwargs):
logits = input["logits"]
probs = logits.softmax(-1).numpy()
pred = logits.argmax(-1).numpy()
logits = logits.numpy()
res = {'predictions': pred, 'probabilities': probs, 'logits': logits}
return res
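
Read together with sbert_for_nli.py above, the intended call flow is preprocess → forward → postprocess. A minimal usage sketch, assuming a local model directory and a dict-of-tensors preprocessor output (both assumptions, not shown in this diff):

# Illustrative only: the model path is hypothetical and the preprocessor's
# exact call signature / output format are assumed.
from modelscope.models.nlp import SbertForNLI
from modelscope.preprocessors.nlp import SbertNLIPreprocessor

model_dir = '/path/to/nlp_structbert_nli_chinese-base'  # hypothetical path
preprocessor = SbertNLIPreprocessor(model_dir)
model = SbertForNLI(model_dir)

inputs = preprocessor('A man is eating.', 'A person eats food.')
outputs = model.forward(inputs)      # -> {'logits': ...}
result = model.postprocess(outputs)  # -> predictions, probabilities, logits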

modelscope/models/nlp/sbert_for_token_classification.py (+4, -3)

@@ -46,10 +46,11 @@ class StructBertForTokenClassification(Model):
}
"""
input_ids = torch.tensor(input['input_ids']).unsqueeze(0)
output = self.model(input_ids)
logits = output.logits
return self.model(input_ids)

def postprocess(self, input: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
logits = input["logits"]
pred = torch.argmax(logits[0], dim=-1)
pred = pred.numpy()

rst = {'predictions': pred, 'logits': logits, 'text': input['text']}
return rst

modelscope/models/nlp/zero_shot_classification_model.py → modelscope/models/nlp/sbert_for_zero_shot_classification.py (+5, -7)

@@ -1,19 +1,19 @@
from typing import Any, Dict

import numpy as np
import torch

from modelscope.utils.constant import Tasks
from ..base import Model
from ..builder import MODELS
from ...metainfo import Models

__all__ = ['BertForZeroShotClassification']
__all__ = ['SbertForZeroShotClassification']


@MODELS.register_module(
Tasks.zero_shot_classification,
module_name=r'bert-zero-shot-classification')
class BertForZeroShotClassification(Model):
module_name=Models.structbert)
class SbertForZeroShotClassification(Model):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the zero shot classification model from the `model_dir` path.
@@ -25,7 +25,6 @@ class BertForZeroShotClassification(Model):
super().__init__(model_dir, *args, **kwargs)
from sofa import SbertForSequenceClassification
self.model = SbertForSequenceClassification.from_pretrained(model_dir)
self.model.eval()

def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
"""return the result by the model
@@ -40,8 +39,7 @@ class BertForZeroShotClassification(Model):
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
}
"""
with torch.no_grad():
outputs = self.model(**input)
outputs = self.model(**input)
logits = outputs['logits'].numpy()
res = {'logits': logits}
return res

modelscope/models/nlp/sentiment_classification_model.py (+0, -85)

@@ -1,85 +0,0 @@
import os
from typing import Any, Dict

import numpy as np
import torch
from sofa import SbertConfig, SbertModel
from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel
from torch import nn
from transformers.activations import ACT2FN, get_activation
from transformers.models.bert.modeling_bert import SequenceClassifierOutput

from modelscope.utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS

__all__ = ['SbertForSentimentClassification']


class SbertTextClassifier(SbertPreTrainedModel):

def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config
self.encoder = SbertModel(config, add_pooling_layer=True)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

def forward(self, input_ids=None, token_type_ids=None):
outputs = self.encoder(
input_ids,
token_type_ids=token_type_ids,
return_dict=None,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
return logits


@MODELS.register_module(
Tasks.sentiment_classification,
module_name=r'sbert-sentiment-classification')
class SbertForSentimentClassification(Model):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the text generation model from the `model_dir` path.

Args:
model_dir (str): the model path.
model_cls (Optional[Any], optional): model loader, if None, use the
default loader to load model weights, by default None.
"""
super().__init__(model_dir, *args, **kwargs)
self.model_dir = model_dir

self.model = SbertTextClassifier.from_pretrained(
model_dir, num_labels=2)
self.model.eval()

def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
"""return the result by the model

Args:
input (Dict[str, Any]): the preprocessed data

Returns:
Dict[str, np.ndarray]: results
Example:
{
'predictions': array([1]), # lable 0-negative 1-positive
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
}
"""
input_ids = torch.tensor(input['input_ids'], dtype=torch.long)
token_type_ids = torch.tensor(
input['token_type_ids'], dtype=torch.long)
with torch.no_grad():
logits = self.model(input_ids, token_type_ids)
probs = logits.softmax(-1).numpy()
pred = logits.argmax(-1).numpy()
logits = logits.numpy()
res = {'predictions': pred, 'probabilities': probs, 'logits': logits}
return res

modelscope/models/nlp/space/dialog_intent_prediction_model.py (+4, -4)

@@ -1,11 +1,11 @@
import os
from typing import Any, Dict

from modelscope.preprocessors.space.fields.intent_field import \
from ....preprocessors.space.fields.intent_field import \
IntentBPETextField
from modelscope.trainers.nlp.space.trainers.intent_trainer import IntentTrainer
from modelscope.utils.config import Config
from modelscope.utils.constant import Tasks
from ....trainers.nlp.space.trainers.intent_trainer import IntentTrainer
from ....utils.config import Config
from ....utils.constant import Tasks
from ...base import Model, Tensor
from ...builder import MODELS
from .model.generator import Generator


modelscope/models/nlp/space/dialog_modeling_model.py (+4, -4)

@@ -1,11 +1,11 @@
import os
from typing import Any, Dict, Optional

from modelscope.preprocessors.space.fields.gen_field import \
from ....preprocessors.space.fields.gen_field import \
MultiWOZBPETextField
from modelscope.trainers.nlp.space.trainers.gen_trainer import MultiWOZTrainer
from modelscope.utils.config import Config
from modelscope.utils.constant import Tasks
from ....trainers.nlp.space.trainers.gen_trainer import MultiWOZTrainer
from ....utils.config import Config
from ....utils.constant import Tasks
from ...base import Model, Tensor
from ...builder import MODELS
from .model.generator import Generator


modelscope/models/nlp/space/model/gen_unified_transformer.py (+1, -1)

@@ -3,7 +3,7 @@ IntentUnifiedTransformer
"""
import torch

from modelscope.models.nlp.space.model.unified_transformer import \
from .unified_transformer import \
UnifiedTransformer




modelscope/models/nlp/space/model/intent_unified_transformer.py (+1, -1)

@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.utils.nlp.space.criterions import compute_kl_loss
from .....utils.nlp.space.criterions import compute_kl_loss
from .unified_transformer import UnifiedTransformer




modelscope/models/nlp/space/model/unified_transformer.py (+3, -3)

@@ -7,9 +7,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.models.nlp.space.model.model_base import ModelBase
from modelscope.models.nlp.space.modules.embedder import Embedder
from modelscope.models.nlp.space.modules.transformer_block import \
from .model_base import ModelBase
from ..modules.embedder import Embedder
from ..modules.transformer_block import \
TransformerBlock




modelscope/models/nlp/space/modules/transformer_block.py (+2, -2)

@@ -5,8 +5,8 @@ TransformerBlock class.
import torch
import torch.nn as nn

from modelscope.models.nlp.space.modules.feedforward import FeedForward
from modelscope.models.nlp.space.modules.multihead_attention import \
from .feedforward import FeedForward
from .multihead_attention import \
MultiheadAttention




modelscope/pipelines/nlp/sentence_similarity_pipeline.py (+5, -5)

@@ -2,10 +2,10 @@ from typing import Any, Dict, Union

import numpy as np

from modelscope.metainfo import Pipelines
from modelscope.models.nlp import SbertForSentenceSimilarity
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.utils.constant import Tasks
from ...metainfo import Pipelines
from ...models.nlp import SbertForSentenceSimilarity
from ...preprocessors import SequenceClassificationPreprocessor
from ...utils.constant import Tasks
from ...models import Model
from ..base import Input, Pipeline
from ..builder import PIPELINES
@@ -18,7 +18,7 @@ __all__ = ['SentenceSimilarityPipeline']
class SentenceSimilarityPipeline(Pipeline):

def __init__(self,
model: Union[SbertForSentenceSimilarity, str],
model: Union[Model, str],
preprocessor: SequenceClassificationPreprocessor = None,
**kwargs):
"""use `model` and `preprocessor` to create a nlp sentence similarity pipeline for prediction


modelscope/pipelines/nlp/space/dialog_intent_prediction_pipeline.py (+5, -5)

@@ -1,10 +1,10 @@
from typing import Any, Dict, Optional
from typing import Any, Dict

from modelscope.models.nlp import DialogIntentModel
from modelscope.preprocessors import DialogIntentPredictionPreprocessor
from modelscope.utils.constant import Tasks
from ...base import Input, Pipeline
from ...base import Pipeline
from ...builder import PIPELINES
from ....models.nlp import DialogIntentModel
from ....preprocessors import DialogIntentPredictionPreprocessor
from ....utils.constant import Tasks

__all__ = ['DialogIntentPredictionPipeline']



modelscope/pipelines/nlp/zero_shot_classification_pipeline.py (+6, -6)

@@ -6,8 +6,8 @@ import json
import numpy as np
from scipy.special import softmax

from modelscope.models.nlp import BertForZeroShotClassification
from modelscope.preprocessors import ZeroShotClassificationPreprocessor
from modelscope.models.nlp import SbertForZeroShotClassification
from modelscope.preprocessors import SbertZeroShotClassificationPreprocessor
from modelscope.utils.constant import Tasks
from ...models import Model
from ..base import Input, Pipeline
@@ -22,8 +22,8 @@ __all__ = ['ZeroShotClassificationPipeline']
class ZeroShotClassificationPipeline(Pipeline):

def __init__(self,
model: Union[BertForZeroShotClassification, str],
preprocessor: ZeroShotClassificationPreprocessor = None,
model: Union[SbertForZeroShotClassification, str],
preprocessor: SbertZeroShotClassificationPreprocessor = None,
**kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction

@@ -31,11 +31,11 @@ class ZeroShotClassificationPipeline(Pipeline):
model (SbertForZeroShotClassification): a model instance
preprocessor (SbertZeroShotClassificationPreprocessor): a preprocessor instance
"""
assert isinstance(model, str) or isinstance(model, BertForZeroShotClassification), \
assert isinstance(model, str) or isinstance(model, SbertForZeroShotClassification), \
'model must be a single str or SbertForZeroShotClassification'
sc_model = model if isinstance(
model,
BertForZeroShotClassification) else Model.from_pretrained(model)
SbertForZeroShotClassification) else Model.from_pretrained(model)

self.entailment_id = 0
self.contradiction_id = 2


modelscope/preprocessors/nlp.py (+15, -14)

@@ -5,17 +5,18 @@ from typing import Any, Dict, Union

from transformers import AutoTokenizer

from modelscope.metainfo import Preprocessors
from modelscope.utils.constant import Fields, InputFields
from modelscope.utils.type_assert import type_assert
from ..metainfo import Preprocessors
from ..metainfo import Models
from ..utils.constant import Fields, InputFields
from ..utils.type_assert import type_assert
from .base import Preprocessor
from .builder import PREPROCESSORS

__all__ = [
'Tokenize', 'SequenceClassificationPreprocessor',
'TextGenerationPreprocessor', 'ZeroShotClassificationPreprocessor',
'TokenClassifcationPreprocessor', 'NLIPreprocessor',
'SentimentClassificationPreprocessor', 'FillMaskPreprocessor'
'PalmTextGenerationPreprocessor', 'SbertZeroShotClassificationPreprocessor',
'SbertTokenClassifcationPreprocessor', 'SbertNLIPreprocessor',
'SbertSentimentClassificationPreprocessor', 'FillMaskPreprocessor'
]


@@ -34,8 +35,8 @@ class Tokenize(Preprocessor):


@PREPROCESSORS.register_module(
Fields.nlp, module_name=r'nlp_structbert_nli_chinese-base')
class NLIPreprocessor(Preprocessor):
Fields.nlp, module_name=Preprocessors.sbert_nli_tokenizer)
class SbertNLIPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path
@@ -104,8 +105,8 @@ class NLIPreprocessor(Preprocessor):


@PREPROCESSORS.register_module(
Fields.nlp, module_name=r'sbert-sentiment-classification')
class SentimentClassificationPreprocessor(Preprocessor):
Fields.nlp, module_name=Preprocessors.sbert_sen_cls_tokenizer)
class SbertSentimentClassificationPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path
@@ -263,7 +264,7 @@ class SequenceClassificationPreprocessor(Preprocessor):

@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.palm_text_gen_tokenizer)
class TextGenerationPreprocessor(Preprocessor):
class PalmTextGenerationPreprocessor(Preprocessor):

def __init__(self, model_dir: str, tokenizer, *args, **kwargs):
"""preprocess the data using the vocab.txt from the `model_dir` path
@@ -373,8 +374,8 @@ class FillMaskPreprocessor(Preprocessor):


@PREPROCESSORS.register_module(
Fields.nlp, module_name=r'bert-zero-shot-classification')
class ZeroShotClassificationPreprocessor(Preprocessor):
Fields.nlp, module_name=Preprocessors.sbert_zero_shot_cls_tokenizer)
class SbertZeroShotClassificationPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path
@@ -418,7 +419,7 @@ class ZeroShotClassificationPreprocessor(Preprocessor):

@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.sbert_token_cls_tokenizer)
class TokenClassifcationPreprocessor(Preprocessor):
class SbertTokenClassifcationPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path


modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py (+4, -4)

@@ -3,11 +3,11 @@
import os
from typing import Any, Dict

from modelscope.preprocessors.space.fields.intent_field import \
from .fields.intent_field import \
IntentBPETextField
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields
from modelscope.utils.type_assert import type_assert
from ...utils.config import Config
from ...utils.constant import Fields
from ...utils.type_assert import type_assert
from ..base import Preprocessor
from ..builder import PREPROCESSORS



modelscope/preprocessors/space/dialog_modeling_preprocessor.py (+5, -6)

@@ -1,16 +1,15 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import uuid
from typing import Any, Dict, Union
from typing import Any, Dict

from modelscope.preprocessors.space.fields.gen_field import \
from .fields.gen_field import \
MultiWOZBPETextField
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, InputFields
from modelscope.utils.type_assert import type_assert
from ..base import Preprocessor
from ..builder import PREPROCESSORS
from ...utils.config import Config
from ...utils.constant import Fields
from ...utils.type_assert import type_assert

__all__ = ['DialogModelingPreprocessor']



modelscope/preprocessors/space/fields/gen_field.py (+4, -4)

@@ -8,10 +8,10 @@ from itertools import chain

import numpy as np

from modelscope.preprocessors.space.tokenizer import Tokenizer
from modelscope.utils.nlp.space import ontology, utils
from modelscope.utils.nlp.space.db_ops import MultiWozDB
from modelscope.utils.nlp.space.utils import list2np
from ..tokenizer import Tokenizer
from ....utils.nlp.space import ontology, utils
from ....utils.nlp.space.db_ops import MultiWozDB
from ....utils.nlp.space.utils import list2np


class BPETextField(object):


modelscope/preprocessors/space/fields/intent_field.py (+4, -4)

@@ -14,10 +14,10 @@ import json
import numpy as np
from tqdm import tqdm

from modelscope.preprocessors.space.tokenizer import Tokenizer
from modelscope.utils.nlp.space import ontology, utils
from modelscope.utils.nlp.space.scores import hierarchical_set_score
from modelscope.utils.nlp.space.utils import list2np
from ..tokenizer import Tokenizer
from ....utils.nlp.space import ontology, utils
from ....utils.nlp.space.scores import hierarchical_set_score
from ....utils.nlp.space.utils import list2np


class BPETextField(object):


modelscope/trainers/nlp/space/trainers/gen_trainer.py (+1, -1)

@@ -13,7 +13,7 @@ import torch
from tqdm import tqdm
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

import modelscope.utils.nlp.space.ontology as ontology
from .....utils.nlp.space import ontology
from ..metrics.metrics_tracker import MetricsTracker




modelscope/trainers/nlp/space/trainers/intent_trainer.py (+1, -2)

@@ -14,9 +14,8 @@ import torch
from tqdm import tqdm
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

from modelscope.trainers.nlp.space.metrics.metrics_tracker import \
from ..metrics.metrics_tracker import \
MetricsTracker
from modelscope.utils.nlp.space.args import str2bool


def get_logger(log_path, name='default'):

