
[to #42322933] 830 NLP passage ranking / text representation model code check

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9856179
master
dingkun.ldk, yingda.chen · 3 years ago
commit 54e1a6d88b
21 changed files with 1035 additions and 18 deletions
  1. modelscope/metainfo.py (+6, -0)
  2. modelscope/models/nlp/__init__.py (+4, -0)
  3. modelscope/models/nlp/passage_ranking.py (+78, -0)
  4. modelscope/models/nlp/sentence_embedding.py (+74, -0)
  5. modelscope/msdatasets/task_datasets/__init__.py (+2, -0)
  6. modelscope/msdatasets/task_datasets/passage_ranking_dataset.py (+151, -0)
  7. modelscope/outputs.py (+2, -7)
  8. modelscope/pipelines/builder.py (+5, -0)
  9. modelscope/pipelines/nlp/__init__.py (+4, -1)
  10. modelscope/pipelines/nlp/passage_ranking_pipeline.py (+58, -0)
  11. modelscope/pipelines/nlp/sentence_embedding_pipeline.py (+60, -0)
  12. modelscope/preprocessors/__init__.py (+3, -1)
  13. modelscope/preprocessors/nlp.py (+103, -1)
  14. modelscope/trainers/__init__.py (+2, -2)
  15. modelscope/trainers/nlp/__init__.py (+2, -0)
  16. modelscope/trainers/nlp/passage_ranking_trainer.py (+197, -0)
  17. modelscope/trainers/trainer.py (+6, -6)
  18. modelscope/utils/constant.py (+2, -0)
  19. tests/pipelines/test_passage_ranking.py (+61, -0)
  20. tests/pipelines/test_sentence_embedding.py (+82, -0)
  21. tests/trainers/test_finetune_passage_ranking.py (+133, -0)

+ 6
- 0
modelscope/metainfo.py

@@ -193,6 +193,8 @@ class Pipelines(object):
plug_generation = 'plug-generation'
faq_question_answering = 'faq-question-answering'
conversational_text_to_sql = 'conversational-text-to-sql'
sentence_embedding = 'sentence-embedding'
passage_ranking = 'passage-ranking'
relation_extraction = 'relation-extraction'
document_segmentation = 'document-segmentation'


@@ -245,6 +247,7 @@ class Trainers(object):
dialog_intent_trainer = 'dialog-intent-trainer'
nlp_base_trainer = 'nlp-base-trainer'
nlp_veco_trainer = 'nlp-veco-trainer'
nlp_passage_ranking_trainer = 'nlp-passage-ranking-trainer'

# audio trainers
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
@@ -272,6 +275,7 @@ class Preprocessors(object):


# nlp preprocessor
sen_sim_tokenizer = 'sen-sim-tokenizer'
cross_encoder_tokenizer = 'cross-encoder-tokenizer'
bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
text_gen_tokenizer = 'text-gen-tokenizer'
token_cls_tokenizer = 'token-cls-tokenizer'
@@ -284,6 +288,8 @@ class Preprocessors(object):
sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'
zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer'
text_error_correction = 'text-error-correction'
sentence_embedding = 'sentence-embedding'
passage_ranking = 'passage-ranking'
sequence_labeling_tokenizer = 'sequence-labeling-tokenizer'
word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor'
fill_mask = 'fill-mask'


+ 4
- 0
modelscope/models/nlp/__init__.py

@@ -29,6 +29,8 @@ if TYPE_CHECKING:
SingleBackboneTaskModelBase,
TokenClassificationModel)
from .token_classification import SbertForTokenClassification
from .sentence_embedding import SentenceEmbedding
from .passage_ranking import PassageRanking

else:
_import_structure = {
@@ -62,6 +64,8 @@ else:
'SingleBackboneTaskModelBase', 'TokenClassificationModel'
],
'token_classification': ['SbertForTokenClassification'],
'sentence_embedding': ['SentenceEmbedding'],
'passage_ranking': ['PassageRanking'],
}

import sys


+ 78
- 0
modelscope/models/nlp/passage_ranking.py

@@ -0,0 +1,78 @@
from typing import Any, Dict

import numpy as np
import torch

from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.nlp import SbertForSequenceClassification
from modelscope.models.nlp.structbert import SbertPreTrainedModel
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks

__all__ = ['PassageRanking']


@MODELS.register_module(Tasks.passage_ranking, module_name=Models.bert)
class PassageRanking(SbertForSequenceClassification, SbertPreTrainedModel):
base_model_prefix: str = 'bert'
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r'position_ids']

def __init__(self, config, model_dir, *args, **kwargs):
if hasattr(config, 'base_model_prefix'):
PassageRanking.base_model_prefix = config.base_model_prefix
super().__init__(config, model_dir)
self.train_batch_size = kwargs.get('train_batch_size', 4)
self.register_buffer(
'target_label',
torch.zeros(self.train_batch_size, dtype=torch.long))

def build_base_model(self):
from .structbert import SbertModel
return SbertModel(self.config, add_pooling_layer=True)

def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
outputs = self.base_model.forward(**input)

# backbone model should return pooled_output as its second output
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
if self.base_model.training:
scores = logits.view(self.train_batch_size, -1)
loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(scores, self.target_label)
return {OutputKeys.LOGITS: logits, OutputKeys.LOSS: loss}
return {OutputKeys.LOGITS: logits}

def sigmoid(self, logits):
return np.exp(logits) / (1 + np.exp(logits))

def postprocess(self, inputs: Dict[str, np.ndarray],
**kwargs) -> Dict[str, np.ndarray]:
logits = inputs['logits'].squeeze(-1).detach().cpu().numpy()
logits = self.sigmoid(logits).tolist()
result = {OutputKeys.SCORES: logits}
return result

@classmethod
def _instantiate(cls, **kwargs):
"""Instantiate the model.

@param kwargs: Input args.
model_dir: The model dir used to load the checkpoint and the label information.
num_labels: An optional arg to tell the model how many classes to initialize.
Method will call utils.parse_label_mapping if num_labels not supplied.
If num_labels is not found, the model will use the default setting (1 class).
@return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
"""

num_labels = kwargs.get('num_labels', 1)
model_args = {} if num_labels is None else {'num_labels': num_labels}

return super(SbertPreTrainedModel, PassageRanking).from_pretrained(
pretrained_model_name_or_path=kwargs.get('model_dir'),
model_dir=kwargs.get('model_dir'),
**model_args)

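Reviewer note: in the training branch of forward(), the per-passage logits are reshaped to [train_batch_size, group_size] and scored with CrossEntropyLoss against the all-zeros `target_label` buffer. This is a listwise objective that works because PassageRankingDataset always places the positive passage at index 0 of each group. A minimal sketch with dummy tensors (shapes assume the defaults above: batch size 4, 1 positive + 4 negatives per query):

import torch

train_batch_size, group_size = 4, 5
logits = torch.randn(train_batch_size * group_size, 1)    # one score per passage
scores = logits.view(train_batch_size, -1)                # [4, 5] per-query groups
target = torch.zeros(train_batch_size, dtype=torch.long)  # positive sits at index 0
loss = torch.nn.CrossEntropyLoss()(scores, target)
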
+ 74
- 0
modelscope/models/nlp/sentence_embedding.py

@@ -0,0 +1,74 @@
import os
from typing import Any, Dict

import json
import numpy as np

from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.nlp.structbert import SbertPreTrainedModel
from modelscope.utils.constant import Tasks

__all__ = ['SentenceEmbedding']


@MODELS.register_module(Tasks.sentence_embedding, module_name=Models.bert)
class SentenceEmbedding(TorchModel, SbertPreTrainedModel):
base_model_prefix: str = 'bert'
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r'position_ids']

def __init__(self, config, model_dir):
super().__init__(model_dir)
self.config = config
setattr(self, self.base_model_prefix, self.build_base_model())

def build_base_model(self):
from .structbert import SbertModel
return SbertModel(self.config, add_pooling_layer=False)

def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
"""return the result by the model

Args:
input (Dict[str, Any]): the preprocessed data

Returns:
Dict[str, np.ndarray]: the backbone outputs
Example:
{
'last_hidden_state': array(...)  # shape [num_sent, seq_len, hidden_size];
# postprocess() takes the [CLS] vector of each sentence as its embedding
}
"""
return self.base_model(**input)

def postprocess(self, inputs: Dict[str, np.ndarray],
**kwargs) -> Dict[str, np.ndarray]:
embs = inputs['last_hidden_state'][:, 0].cpu().numpy()
num_sent = embs.shape[0]
if num_sent >= 2:
scores = np.dot(embs[0:1, ], np.transpose(embs[1:, ],
(1, 0))).tolist()[0]
else:
scores = []
result = {'text_embedding': embs, 'scores': scores}

return result

@classmethod
def _instantiate(cls, **kwargs):
"""Instantiate the model.

@param kwargs: Input args.
model_dir: The model dir used to load the checkpoint and the label information.
@return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
"""
model_args = {}

return super(SbertPreTrainedModel, SentenceEmbedding).from_pretrained(
pretrained_model_name_or_path=kwargs.get('model_dir'),
model_dir=kwargs.get('model_dir'),
**model_args)

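Reviewer note: postprocess() takes the [CLS] vector of each sentence as its embedding and scores the source sentence (row 0) against every compare sentence with a raw dot product; there is no L2 normalization, so the scores are not cosine similarities. A toy sketch of the same computation with random arrays:

import numpy as np

last_hidden_state = np.random.rand(3, 16, 768)  # [num_sent, seq_len, hidden]
embs = last_hidden_state[:, 0]                  # [CLS] embeddings, shape [3, 768]
scores = np.dot(embs[0:1], np.transpose(embs[1:], (1, 0))).tolist()[0]
print(len(scores))  # 2 scores: source vs. each of the two compare sentences
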
+ 2
- 0
modelscope/msdatasets/task_datasets/__init__.py

@@ -11,12 +11,14 @@ if TYPE_CHECKING:
from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset
from .movie_scene_segmentation import MovieSceneSegmentationDataset
from .video_summarization_dataset import VideoSummarizationDataset
from .passage_ranking_dataset import PassageRankingDataset

else:
_import_structure = {
'base': ['TaskDataset'],
'builder': ['TASK_DATASETS', 'build_task_dataset'],
'torch_base_dataset': ['TorchTaskDataset'],
'passage_ranking_dataset': ['PassageRankingDataset'],
'veco_dataset': ['VecoDataset'],
'image_instance_segmentation_coco_dataset':
['ImageInstanceSegmentationCocoDataset'],


+ 151
- 0
modelscope/msdatasets/task_datasets/passage_ranking_dataset.py

@@ -0,0 +1,151 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import random
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union

import torch
from datasets import Dataset, IterableDataset, concatenate_datasets
from torch.utils.data import ConcatDataset
from transformers import DataCollatorWithPadding

from modelscope.metainfo import Models
from modelscope.utils.constant import ModeKeys, Tasks
from .base import TaskDataset
from .builder import TASK_DATASETS
from .torch_base_dataset import TorchTaskDataset


@TASK_DATASETS.register_module(
group_key=Tasks.passage_ranking, module_name=Models.bert)
class PassageRankingDataset(TorchTaskDataset):

def __init__(self,
datasets: Union[Any, List[Any]],
mode,
preprocessor=None,
*args,
**kwargs):
self.seed = kwargs.get('seed', 42)
self.permutation = None
self.datasets = None
self.dataset_config = kwargs
self.query_sequence = self.dataset_config.get('query_sequence',
'query')
self.pos_sequence = self.dataset_config.get('pos_sequence',
'positive_passages')
self.neg_sequence = self.dataset_config.get('neg_sequence',
'negative_passages')
self.passage_text_fileds = self.dataset_config.get(
'passage_text_fileds', ['title', 'text'])
self.qid_field = self.dataset_config.get('qid_field', 'query_id')
if mode == ModeKeys.TRAIN:
train_config = kwargs.get('train', {})
self.neg_samples = train_config.get('neg_samples', 4)

super().__init__(datasets, mode, preprocessor, **kwargs)

def __getitem__(self, index) -> Any:
if self.mode == ModeKeys.TRAIN:
return self.__get_train_item__(index)
else:
return self.__get_test_item__(index)

def __get_test_item__(self, index):
group = self._inner_dataset[index]
labels = []

qry = group[self.query_sequence]

pos_sequences = group[self.pos_sequence]
pos_sequences = [
' '.join([ele[key] for key in self.passage_text_fileds])
for ele in pos_sequences
]
labels.extend([1] * len(pos_sequences))

neg_sequences = group[self.neg_sequence]
neg_sequences = [
' '.join([ele[key] for key in self.passage_text_fileds])
for ele in neg_sequences
]

labels.extend([0] * len(neg_sequences))
qid = group[self.qid_field]

examples = pos_sequences + neg_sequences
sample = {
'qid': torch.LongTensor([int(qid)] * len(labels)),
self.preprocessor.first_sequence: qry,
self.preprocessor.second_sequence: examples,
'labels': torch.LongTensor(labels)
}
return self.prepare_sample(sample)

def __get_train_item__(self, index):
group = self._inner_dataset[index]

qry = group[self.query_sequence]

pos_sequences = group[self.pos_sequence]
pos_sequences = [
' '.join([ele[key] for key in self.passage_text_fileds])
for ele in pos_sequences
]

neg_sequences = group[self.neg_sequence]
neg_sequences = [
' '.join([ele[key] for key in self.passage_text_fileds])
for ele in neg_sequences
]

pos_psg = random.choice(pos_sequences)

if len(neg_sequences) < self.neg_samples:
negs = random.choices(neg_sequences, k=self.neg_samples)
else:
negs = random.sample(neg_sequences, k=self.neg_samples)
examples = [pos_psg] + negs
sample = {
self.preprocessor.first_sequence: qry,
self.preprocessor.second_sequence: examples,
}
return self.prepare_sample(sample)

def __len__(self):
return len(self._inner_dataset)

def prepare_dataset(self, datasets: Union[Any, List[Any]]) -> Any:
"""Prepare a dataset.

Users can process the input datasets from a whole-dataset perspective.
This method gives a default implementation of dataset merging; users can override this
method to implement custom logic.

Args:
datasets: The original dataset(s)

Returns: A single dataset, which may be created after merging.

"""
if isinstance(datasets, List):
if len(datasets) == 1:
return datasets[0]
elif len(datasets) > 1:
return ConcatDataset(datasets)
else:
return datasets

def prepare_sample(self, data):
"""Preprocess the data fetched from the inner_dataset.

If the preprocessor is None, the original data will be returned; otherwise the preprocessor will be called.
Users can override this method to implement custom logic.

Args:
data: The data fetched from the dataset.

Returns: The processed data.

"""
return self.preprocessor(
data) if self.preprocessor is not None else data

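Reviewer note: the field names below follow the dataset-config defaults above (query / positive_passages / negative_passages / query_id, with title and text per passage); this is an illustrative record, not an actual schema from the hub:

record = {
    'query_id': '0',
    'query': "how long it take to get a master's degree",
    'positive_passages': [
        {'title': 'Masters degree', 'text': 'About 18 to 24 months on average.'},
    ],
    'negative_passages': [
        {'title': 'Doctorate', 'text': 'A PhD usually takes longer.'},
    ],
}
# A train item becomes {first_sequence: query, second_sequence: [pos] + neg_samples negs};
# a test item additionally carries LongTensor 'qid' and 'labels' fields.
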
+ 2
- 7
modelscope/outputs.py

@@ -387,19 +387,14 @@ TASK_OUTPUTS = {
# "output": "我想吃苹果" # "output": "我想吃苹果"
# } # }
Tasks.text_error_correction: [OutputKeys.OUTPUT], Tasks.text_error_correction: [OutputKeys.OUTPUT],

Tasks.sentence_embedding: [OutputKeys.TEXT_EMBEDDING, OutputKeys.SCORES],
Tasks.passage_ranking: [OutputKeys.SCORES],
# text generation result for single sample # text generation result for single sample
# { # {
# "text": "this is the text generated by a model." # "text": "this is the text generated by a model."
# } # }
Tasks.text_generation: [OutputKeys.TEXT], Tasks.text_generation: [OutputKeys.TEXT],


# text feature extraction for single sample
# {
# "text_embedding": np.array with shape [1, D]
# }
Tasks.sentence_embedding: [OutputKeys.TEXT_EMBEDDING],

# fill mask result for single sample
# {
# "text": "this is the text which masks filled by model."


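Reviewer note: the two new task-output contracts, with illustrative payloads (values are made up):

# sentence-embedding:
#   {'text_embedding': np.ndarray of shape [num_sent, hidden_size],
#    'scores': [0.83, 0.41]}   # source sentence vs. each compare sentence
# passage-ranking:
#   {'scores': [0.97, 0.12]}   # sigmoid relevance score per passage
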
+ 5
- 0
modelscope/pipelines/builder.py

@@ -17,6 +17,11 @@ PIPELINES = Registry('pipelines')


DEFAULT_MODEL_FOR_PIPELINE = {
# TaskName: (pipeline_module_name, model_repo)
Tasks.sentence_embedding:
(Pipelines.sentence_embedding,
'damo/nlp_corom_sentence-embedding_english-base'),
Tasks.passage_ranking: (Pipelines.passage_ranking,
'damo/nlp_corom_passage-ranking_english-base'),
Tasks.word_segmentation:
(Pipelines.word_segmentation,
'damo/nlp_structbert_word-segmentation_chinese-base'),


+ 4
- 1
modelscope/pipelines/nlp/__init__.py

@@ -25,7 +25,8 @@ if TYPE_CHECKING:
from .translation_pipeline import TranslationPipeline
from .word_segmentation_pipeline import WordSegmentationPipeline
from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline

from .passage_ranking_pipeline import PassageRankingPipeline
from .sentence_embedding_pipeline import SentenceEmbeddingPipeline
else:
_import_structure = {
'conversational_text_to_sql_pipeline':
@@ -55,6 +56,8 @@ else:
'word_segmentation_pipeline': ['WordSegmentationPipeline'],
'zero_shot_classification_pipeline':
['ZeroShotClassificationPipeline'],
'passage_ranking_pipeline': ['PassageRankingPipeline'],
'sentence_embedding_pipeline': ['SentenceEmbeddingPipeline']
}

import sys


+ 58
- 0
modelscope/pipelines/nlp/passage_ranking_pipeline.py

@@ -0,0 +1,58 @@
from typing import Any, Dict, Optional, Union

import torch

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import PassageRankingPreprocessor, Preprocessor
from modelscope.utils.constant import Tasks

__all__ = ['PassageRankingPipeline']


@PIPELINES.register_module(
Tasks.passage_ranking, module_name=Pipelines.passage_ranking)
class PassageRankingPipeline(Pipeline):

def __init__(self,
model: Union[Model, str],
preprocessor: Optional[Preprocessor] = None,
**kwargs):
"""Use `model` and `preprocessor` to create a nlp word segment pipeline for prediction.

Args:
model (str or Model): Supply either a local model dir which supported the WS task,
or a model id from the model hub, or a torch model instance.
preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
the model if supplied.
sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value.
"""
model = model if isinstance(model,
Model) else Model.from_pretrained(model)

if preprocessor is None:
preprocessor = PassageRankingPreprocessor(
model.model_dir if isinstance(model, Model) else model,
sequence_length=kwargs.pop('sequence_length', 128))
model.eval()
super().__init__(model=model, preprocessor=preprocessor, **kwargs)

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
return {**self.model(inputs, **forward_params)}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""process the prediction results
Args:
inputs (Dict[str, Any]): _description_

Returns:
Dict[str, Any]: the predicted text representation
"""
pred_list = inputs[OutputKeys.SCORES]

return {OutputKeys.SCORES: pred_list}

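Reviewer note: usage mirrors the tests below; a minimal sketch using the default model registered in builder.py:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(task=Tasks.passage_ranking,
                model='damo/nlp_corom_passage-ranking_english-base')
result = pipe(input={
    'source_sentence': ["how long it take to get a master's degree"],
    'sentences_to_compare': [
        "On average, students take about 18 to 24 months to complete a master's degree.",
        'It can take anywhere from two semesters',
    ],
})
print(result)  # {'scores': [..., ...]}, one relevance score per compared passage
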
+ 60
- 0
modelscope/pipelines/nlp/sentence_embedding_pipeline.py

@@ -0,0 +1,60 @@
from typing import Any, Dict, Optional, Union

import torch

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import (Preprocessor,
SentenceEmbeddingPreprocessor)
from modelscope.utils.constant import Tasks

__all__ = ['SentenceEmbeddingPipeline']


@PIPELINES.register_module(
Tasks.sentence_embedding, module_name=Pipelines.sentence_embedding)
class SentenceEmbeddingPipeline(Pipeline):

def __init__(self,
model: Union[Model, str],
preprocessor: Optional[Preprocessor] = None,
first_sequence='first_sequence',
**kwargs):
"""Use `model` and `preprocessor` to create a nlp text dual encoder then generates the text representation.
Args:
model (str or Model): Supply either a local model dir which supported the WS task,
or a model id from the model hub, or a torch model instance.
preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
the model if supplied.
sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value.
"""
model = model if isinstance(model,
Model) else Model.from_pretrained(model)
if preprocessor is None:
preprocessor = SentenceEmbeddingPreprocessor(
model.model_dir if isinstance(model, Model) else model,
first_sequence=first_sequence,
sequence_length=kwargs.pop('sequence_length', 128))
model.eval()
super().__init__(model=model, preprocessor=preprocessor, **kwargs)

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
return {**self.model(inputs, **forward_params)}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""process the prediction results

Args:
inputs (Dict[str, Any]): _description_

Returns:
Dict[str, Any]: the predicted text representation
"""
embs = inputs[OutputKeys.TEXT_EMBEDDING]
scores = inputs[OutputKeys.SCORES]
return {OutputKeys.TEXT_EMBEDDING: embs, OutputKeys.SCORES: scores}

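Reviewer note: a matching usage sketch for this pipeline, with the default model from builder.py:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(task=Tasks.sentence_embedding,
                model='damo/nlp_corom_sentence-embedding_english-base')
result = pipe(input={
    'source_sentence': ["how long it take to get a master's degree"],
    'sentences_to_compare': [
        "On average, students take about 18 to 24 months to complete a master's degree.",
        'It can take anywhere from two semesters',
    ],
})
print(result['text_embedding'].shape)  # [3, hidden_size]: source + 2 compare sentences
print(result['scores'])                # dot-product score of source vs. each compare sentence
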
+ 3
- 1
modelscope/preprocessors/__init__.py

@@ -23,7 +23,8 @@ if TYPE_CHECKING:
ZeroShotClassificationPreprocessor, NERPreprocessor,
TextErrorCorrectionPreprocessor, FaqQuestionAnsweringPreprocessor,
SequenceLabelingPreprocessor, RelationExtractionPreprocessor,
DocumentSegmentationPreprocessor, FillMaskPoNetPreprocessor)
DocumentSegmentationPreprocessor, FillMaskPoNetPreprocessor,
PassageRankingPreprocessor)
from .space import (DialogIntentPredictionPreprocessor,
DialogModelingPreprocessor,
DialogStateTrackingPreprocessor)
@@ -50,6 +51,7 @@ else:
'SingleSentenceClassificationPreprocessor',
'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
'SentenceEmbeddingPreprocessor', 'PassageRankingPreprocessor',
'TextErrorCorrectionPreprocessor',
'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor',
'RelationExtractionPreprocessor',


+ 103
- 1
modelscope/preprocessors/nlp.py

@@ -29,6 +29,7 @@ __all__ = [
'PairSentenceClassificationPreprocessor',
'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
'SentenceEmbeddingPreprocessor', 'PassageRankingPreprocessor',
'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor',
'SequenceLabelingPreprocessor', 'RelationExtractionPreprocessor',
'DocumentSegmentationPreprocessor', 'FillMaskPoNetPreprocessor'
@@ -100,6 +101,7 @@ class SequenceClassificationPreprocessor(Preprocessor):


text_a = new_data[self.first_sequence]
text_b = new_data.get(self.second_sequence, None)

feature = self.tokenizer(
text_a,
text_b,
@@ -111,7 +113,6 @@ class SequenceClassificationPreprocessor(Preprocessor):
rst['input_ids'].append(feature['input_ids'])
rst['attention_mask'].append(feature['attention_mask'])
rst['token_type_ids'].append(feature['token_type_ids'])

return rst




@@ -268,6 +269,62 @@ class NLPTokenizerPreprocessorBase(Preprocessor):
output[OutputKeys.LABELS] = labels




@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.passage_ranking)
class PassageRankingPreprocessor(NLPTokenizerPreprocessorBase):
"""The tokenizer preprocessor used in passage ranking model.
"""

def __init__(self,
model_dir: str,
mode=ModeKeys.INFERENCE,
*args,
**kwargs):
"""preprocess the data

Args:
model_dir (str): model path
"""
super().__init__(model_dir, pair=True, mode=mode, *args, **kwargs)
self.model_dir: str = model_dir
self.first_sequence: str = kwargs.pop('first_sequence',
'source_sentence')
self.second_sequence = kwargs.pop('second_sequence',
'sentences_to_compare')
self.sequence_length = kwargs.pop('sequence_length', 128)

self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)

@type_assert(object, (str, tuple, Dict))
def __call__(self, data: Union[tuple, Dict]) -> Dict[str, Any]:
if isinstance(data, tuple):
sentence1, sentence2 = data
elif isinstance(data, dict):
sentence1 = data.get(self.first_sequence)
sentence2 = data.get(self.second_sequence)
if isinstance(sentence2, str):
sentence2 = [sentence2]
if isinstance(sentence1, str):
sentence1 = [sentence1]
sentence1 = sentence1 * len(sentence2)

max_seq_length = self.sequence_length
feature = self.tokenizer(
sentence1,
sentence2,
padding='max_length',
truncation=True,
max_length=max_seq_length,
return_tensors='pt')
if 'labels' in data:
labels = data['labels']
feature['labels'] = labels
if 'qid' in data:
qid = data['qid']
feature['qid'] = qid
return feature


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.nli_tokenizer)
@PREPROCESSORS.register_module(
@@ -298,6 +355,51 @@ class SingleSentenceClassificationPreprocessor(NLPTokenizerPreprocessorBase):
super().__init__(model_dir, pair=False, mode=mode, **kwargs)




@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.sentence_embedding)
class SentenceEmbeddingPreprocessor(NLPTokenizerPreprocessorBase):
"""The tokenizer preprocessor used in sentence embedding.
"""

def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs):
kwargs['truncation'] = kwargs.get('truncation', True)
kwargs['padding'] = kwargs.get(
'padding', False if mode == ModeKeys.INFERENCE else 'max_length')
kwargs['max_length'] = kwargs.pop('sequence_length', 128)
super().__init__(model_dir, pair=False, mode=mode, **kwargs)

def __call__(self, data: Union[str, Dict]) -> Dict[str, Any]:
"""process the raw input data

Args:
data Dict:
keys: "source_sentence" && "sentences_to_compare"
values: list of sentences
Example:
{"source_sentence": ["how long it take to get a master's degree"],
"sentences_to_compare": ["On average, students take about 18 to 24 months
to complete a master's degree.",
"On the other hand, some students prefer to go at a slower pace
and choose to take several years to complete their studies.",
"It can take anywhere from two semesters"]}
Returns:
Dict[str, Any]: the preprocessed data
"""
source_sentence = data['source_sentence']
compare_sentences = data['sentences_to_compare']
sentences = [source_sentence[0]] + list(compare_sentences)

tokenized_inputs = self.tokenizer(
sentences,
return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None,
padding=True,
truncation=True)
return tokenized_inputs


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.zero_shot_cls_tokenizer)
class ZeroShotClassificationPreprocessor(NLPTokenizerPreprocessorBase):


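Reviewer note: PassageRankingPreprocessor replicates the single source sentence so that each (query, passage) pair is tokenized as one row, padded to sequence_length. A hedged sketch ('model_dir' is a placeholder for a downloaded model directory):

from modelscope.preprocessors import PassageRankingPreprocessor

prep = PassageRankingPreprocessor('model_dir', sequence_length=128)
feature = prep({
    'source_sentence': ['how long is a master degree'],
    'sentences_to_compare': ['about two years', 'a PhD takes longer'],
})
print(feature['input_ids'].shape)  # [2, 128]: one row per (query, passage) pair
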
+ 2
- 2
modelscope/trainers/__init__.py

@@ -11,7 +11,7 @@ if TYPE_CHECKING:
ImagePortraitEnhancementTrainer,
MovieSceneSegmentationTrainer)
from .multi_modal import CLIPTrainer
from .nlp import SequenceClassificationTrainer
from .nlp import SequenceClassificationTrainer, PassageRankingTrainer
from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer
from .trainer import EpochBasedTrainer


@@ -25,7 +25,7 @@ else:
'ImagePortraitEnhancementTrainer', 'MovieSceneSegmentationTrainer'
],
'multi_modal': ['CLIPTrainer'],
'nlp': ['SequenceClassificationTrainer'],
'nlp': ['SequenceClassificationTrainer', 'PassageRankingTrainer'],
'nlp_trainer': ['NlpEpochBasedTrainer', 'VecoTrainer'],
'trainer': ['EpochBasedTrainer']
}


+ 2
- 0
modelscope/trainers/nlp/__init__.py

@@ -6,10 +6,12 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .sequence_classification_trainer import SequenceClassificationTrainer
from .csanmt_translation_trainer import CsanmtTranslationTrainer
from .passage_ranking_trainer import PassageRankingTrainer
else:
_import_structure = {
'sequence_classification_trainer': ['SequenceClassificationTrainer'],
'csanmt_translation_trainer': ['CsanmtTranslationTrainer'],
'passage_ranking_trainer': ['PassageRankingTrainer']
}

import sys


+ 197
- 0
modelscope/trainers/nlp/passage_ranking_trainer.py

@@ -0,0 +1,197 @@
import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

from modelscope.metainfo import Trainers
from modelscope.models.base import Model, TorchModel
from modelscope.msdatasets.ms_dataset import MsDataset
from modelscope.preprocessors.base import Preprocessor
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.logger import get_logger

logger = get_logger()


@dataclass
class GroupCollator():
"""
Collator that flattens grouped features (a list of per-query groups) into a
single flat list and concatenates each field's tensors along dim 0, so the
model receives one batch and is abstracted away from the grouped-data layout.
"""

def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
if isinstance(features[0], list):
features = sum(features, [])
keys = features[0].keys()
batch = {k: list() for k in keys}
for ele in features:
for k, v in ele.items():
batch[k].append(v)
batch = {k: torch.cat(v, dim=0) for k, v in batch.items()}
return batch


@TRAINERS.register_module(module_name=Trainers.nlp_passage_ranking_trainer)
class PassageRankingTrainer(NlpEpochBasedTrainer):

def __init__(
self,
model: Optional[Union[TorchModel, nn.Module, str]] = None,
cfg_file: Optional[str] = None,
cfg_modify_fn: Optional[Callable] = None,
arg_parse_fn: Optional[Callable] = None,
data_collator: Optional[Callable] = None,
train_dataset: Optional[Union[MsDataset, Dataset]] = None,
eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
preprocessor: Optional[Preprocessor] = None,
optimizers: Tuple[torch.optim.Optimizer,
torch.optim.lr_scheduler._LRScheduler] = (None,
None),
model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
**kwargs):

if data_collator is None:
data_collator = GroupCollator()

super().__init__(
model=model,
cfg_file=cfg_file,
cfg_modify_fn=cfg_modify_fn,
arg_parse_fn=arg_parse_fn,
data_collator=data_collator,
preprocessor=preprocessor,
optimizers=optimizers,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
model_revision=model_revision,
**kwargs)

def compute_mrr(self, result, k=10):
mrr = 0
for res in result.values():
sorted_res = sorted(res, key=lambda x: x[0], reverse=True)
ar = 0
for index, ele in enumerate(sorted_res[:k]):
if str(ele[1]) == '1':
ar = 1.0 / (index + 1)
break
mrr += ar
return mrr / len(result)

def compute_ndcg(self, result, k=10):
ndcg = 0
from sklearn.metrics import ndcg_score
for res in result.values():
sorted_res = sorted(res, key=lambda x: x[0], reverse=True)
labels = np.array([[ele[1] for ele in sorted_res]])
scores = np.array([[ele[0] for ele in sorted_res]])
ndcg += float(ndcg_score(labels, scores, k=k))
ndcg = ndcg / len(result)
return ndcg

def evaluate(self,
checkpoint_path: Optional[str] = None,
*args,
**kwargs) -> Dict[str, float]:
"""evaluate a dataset

evaluate a dataset via a specific model from the `checkpoint_path` path, if the `checkpoint_path`
does not exist, read from the config file.

Args:
checkpoint_path (Optional[str], optional): the model path. Defaults to None.

Returns:
Dict[str, float]: the evaluation results
Example:
{"mrr@10": 0.5091, "ndcg@10": 0.6737}
"""
from modelscope.models.nlp import PassageRanking
# get the raw online dataset
self.eval_dataloader = self._build_dataloader_with_dataset(
self.eval_dataset,
**self.cfg.evaluation.get('dataloader', {}),
collate_fn=self.eval_data_collator)
# generate a standard dataloader
# generate a model
if checkpoint_path is not None:
model = PassageRanking.from_pretrained(checkpoint_path)
else:
model = self.model

# copy from easynlp (start)
model.eval()
total_samples = 0

logits_list = list()
label_list = list()
qid_list = list()

total_spent_time = 0.0
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model.to(device)
for _step, batch in enumerate(self.eval_dataloader):
try:
batch = {
key:
val.to(device) if isinstance(val, torch.Tensor) else val
for key, val in batch.items()
}
except RuntimeError:
batch = {key: val for key, val in batch.items()}

infer_start_time = time.time()
with torch.no_grad():
label_ids = batch.pop('labels').detach().cpu().numpy()
qids = batch.pop('qid').detach().cpu().numpy()
outputs = model(batch)
infer_end_time = time.time()
total_spent_time += infer_end_time - infer_start_time
total_samples += self.eval_dataloader.batch_size

assert 'scores' in outputs
logits = outputs['scores']

label_list.extend(label_ids)
logits_list.extend(logits)
qid_list.extend(qids)

logger.info('Inference time = {:.2f}s, [{:.4f} ms / sample] '.format(
total_spent_time, total_spent_time * 1000 / total_samples))

rank_result = {}
for qid, score, label in zip(qid_list, logits_list, label_list):
if qid not in rank_result:
rank_result[qid] = []
rank_result[qid].append((score, label))

for qid in rank_result:
rank_result[qid] = sorted(rank_result[qid], key=lambda x: x[0])

eval_outputs = list()
for metric in self.metrics:
if metric.startswith('mrr'):
k = metric.split('@')[-1]
k = int(k)
mrr = self.compute_mrr(rank_result, k=k)
logger.info('{}: {}'.format(metric, mrr))
eval_outputs.append((metric, mrr))
elif metric.startswith('ndcg'):
k = metric.split('@')[-1]
k = int(k)
ndcg = self.compute_ndcg(rank_result, k=k)
logger.info('{}: {}'.format(metric, ndcg))
eval_outputs.append((metric, ndcg))
else:
raise NotImplementedError('Metric %s not implemented' % metric)

return dict(eval_outputs)

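Reviewer note: a toy check of compute_mrr() (scores and labels are made up): each query's passages are sorted by score descending and the reciprocal rank of the first positive label is averaged over queries.

rank_result = {
    'q1': [(0.9, 0), (0.7, 1), (0.1, 0)],  # first positive at rank 2 -> 1/2
    'q2': [(0.8, 1), (0.3, 0)],            # first positive at rank 1 -> 1/1
}
# trainer.compute_mrr(rank_result, k=10) == (0.5 + 1.0) / 2 == 0.75
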
+ 6
- 6
modelscope/trainers/trainer.py

@@ -345,12 +345,12 @@ class EpochBasedTrainer(BaseTrainer):
type=self.cfg.task, mode=mode, datasets=datasets)
return build_task_dataset(cfg, self.cfg.task)
else:
task_data_config.update(
dict(
mode=mode,
datasets=datasets,
preprocessor=preprocessor))
return build_task_dataset(task_data_config, self.cfg.task)
# avoid add no str value datasets, preprocessors in cfg
task_data_build_config = ConfigDict(
mode=mode, datasets=datasets, preprocessor=preprocessor)
task_data_build_config.update(task_data_config)
return build_task_dataset(task_data_build_config,
self.cfg.task)
except Exception:
if isinstance(datasets, (List, Tuple)) or preprocessor is not None:
return TorchTaskDataset(


+ 2
- 0
modelscope/utils/constant.py

@@ -89,6 +89,8 @@ class NLPTasks(object):
sentiment_analysis = 'sentiment-analysis'
sentence_similarity = 'sentence-similarity'
text_classification = 'text-classification'
sentence_embedding = 'sentence-embedding'
passage_ranking = 'passage-ranking'
relation_extraction = 'relation-extraction'
zero_shot = 'zero-shot'
translation = 'translation'


+ 61
- 0
tests/pipelines/test_passage_ranking.py

@@ -0,0 +1,61 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import shutil
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import PassageRanking
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import PassageRankingPipeline
from modelscope.preprocessors import PassageRankingPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class PassageRankingTest(unittest.TestCase):
model_id = 'damo/nlp_corom_passage-ranking_english-base'
inputs = {
'source_sentence': ["how long it take to get a master's degree"],
'sentences_to_compare': [
"On average, students take about 18 to 24 months to complete a master's degree.",
'On the other hand, some students prefer to go at a slower pace and choose to take '
'several years to complete their studies.',
'It can take anywhere from two semesters'
]
}

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self):
cache_path = snapshot_download(self.model_id)
tokenizer = PassageRankingPreprocessor(cache_path)
model = PassageRanking.from_pretrained(cache_path)
pipeline1 = PassageRankingPipeline(model, preprocessor=tokenizer)
pipeline2 = pipeline(
Tasks.passage_ranking, model=model, preprocessor=tokenizer)
print(f'sentence: {self.inputs}\n'
f'pipeline1:{pipeline1(input=self.inputs)}')
print()
print(f'pipeline2: {pipeline2(input=self.inputs)}')

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
tokenizer = PassageRankingPreprocessor(model.model_dir)
pipeline_ins = pipeline(
task=Tasks.passage_ranking, model=model, preprocessor=tokenizer)
print(pipeline_ins(input=self.inputs))

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_name(self):
pipeline_ins = pipeline(
task=Tasks.passage_ranking, model=self.model_id)
print(pipeline_ins(input=self.inputs))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.passage_ranking)
print(pipeline_ins(input=self.inputs))


if __name__ == '__main__':
unittest.main()

+ 82
- 0
tests/pipelines/test_sentence_embedding.py

@@ -0,0 +1,82 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import shutil
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SentenceEmbedding
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import SentenceEmbeddingPipeline
from modelscope.preprocessors import SentenceEmbeddingPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class SentenceEmbeddingTest(unittest.TestCase):
model_id = 'damo/nlp_corom_sentence-embedding_english-base'
inputs = {
'source_sentence': ["how long it take to get a master's degree"],
'sentences_to_compare': [
"On average, students take about 18 to 24 months to complete a master's degree.",
'On the other hand, some students prefer to go at a slower pace and choose to take ',
'several years to complete their studies.',
'It can take anywhere from two semesters'
]
}

inputs2 = {
'source_sentence': ["how long it take to get a master's degree"],
'sentences_to_compare': [
"On average, students take about 18 to 24 months to complete a master's degree."
]
}

inputs3 = {
'source_sentence': ["how long it take to get a master's degree"],
'sentences_to_compare': []
}

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self):
cache_path = snapshot_download(self.model_id)
tokenizer = SentenceEmbeddingPreprocessor(cache_path)
model = SentenceEmbedding.from_pretrained(cache_path)
pipeline1 = SentenceEmbeddingPipeline(model, preprocessor=tokenizer)
pipeline2 = pipeline(
Tasks.sentence_embedding, model=model, preprocessor=tokenizer)
print(f'inputs: {self.inputs}\n'
f'pipeline1:{pipeline1(input=self.inputs)}')
print()
print(f'pipeline2: {pipeline2(input=self.inputs)}')
print()
print(f'inputs: {self.inputs2}\n'
f'pipeline1:{pipeline1(input=self.inputs2)}')
print()
print(f'pipeline2: {pipeline2(input=self.inputs2)}')
print(f'inputs: {self.inputs3}\n'
f'pipeline1:{pipeline1(input=self.inputs3)}')
print()
print(f'pipeline2: {pipeline2(input=self.inputs3)}')

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
tokenizer = SentenceEmbeddingPreprocessor(model.model_dir)
pipeline_ins = pipeline(
task=Tasks.sentence_embedding, model=model, preprocessor=tokenizer)
print(pipeline_ins(input=self.inputs))

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_name(self):
pipeline_ins = pipeline(
task=Tasks.sentence_embedding, model=self.model_id)
print(pipeline_ins(input=self.inputs))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.sentence_embedding)
print(pipeline_ins(input=self.inputs))


if __name__ == '__main__':
unittest.main()

+ 133
- 0
tests/trainers/test_finetune_passage_ranking.py

@@ -0,0 +1,133 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

import torch
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from modelscope.metainfo import Trainers
from modelscope.models import Model
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile, Tasks


class TestFinetunePassageRanking(unittest.TestCase):
inputs = {
'source_sentence': ["how long it take to get a master's degree"],
'sentences_to_compare': [
"On average, students take about 18 to 24 months to complete a master's degree.",
'On the other hand, some students prefer to go at a slower pace and choose to take '
'several years to complete their studies.',
'It can take anywhere from two semesters'
]
}

def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)

def tearDown(self):
shutil.rmtree(self.tmp_dir)
super().tearDown()

def finetune(self,
model_id,
train_dataset,
eval_dataset,
name=Trainers.nlp_passage_ranking_trainer,
cfg_modify_fn=None,
**kwargs):
kwargs = dict(
model=model_id,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
work_dir=self.tmp_dir,
cfg_modify_fn=cfg_modify_fn,
**kwargs)

os.environ['LOCAL_RANK'] = '0'
trainer = build_trainer(name=name, default_args=kwargs)
trainer.train()
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)

def test_finetune_msmarco(self):

def cfg_modify_fn(cfg):
cfg.task = 'passage-ranking'
cfg['preprocessor'] = {'type': 'passage-ranking'}
cfg.train.optimizer.lr = 2e-5
cfg['dataset'] = {
'train': {
'type': 'bert',
'query_sequence': 'query',
'pos_sequence': 'positive_passages',
'neg_sequence': 'negative_passages',
'passage_text_fileds': ['title', 'text'],
'qid_field': 'query_id'
},
'val': {
'type': 'bert',
'query_sequence': 'query',
'pos_sequence': 'positive_passages',
'neg_sequence': 'negative_passages',
'passage_text_fileds': ['title', 'text'],
'qid_field': 'query_id'
},
}
cfg['train']['neg_samples'] = 4
cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30
cfg.train.max_epochs = 1
cfg.train.train_batch_size = 4
cfg.train.lr_scheduler = {
'type': 'LinearLR',
'start_factor': 1.0,
'end_factor': 0.0,
'options': {
'by_epoch': False
}
}
cfg.train.hooks = [{
'type': 'CheckpointHook',
'interval': 1
}, {
'type': 'TextLoggerHook',
'interval': 1
}, {
'type': 'IterTimerHook'
}, {
'type': 'EvaluationHook',
'by_epoch': False,
'interval': 3000
}]
return cfg

# load dataset
ds = MsDataset.load('passage-ranking-demo', 'zyznull')
train_ds = ds['train'].to_hf_dataset()
dev_ds = ds['train'].to_hf_dataset()

self.finetune(
model_id='damo/nlp_corom_passage-ranking_english-base',
train_dataset=train_ds,
eval_dataset=dev_ds,
cfg_modify_fn=cfg_modify_fn)

output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)
self.pipeline_passage_ranking(output_dir)

def pipeline_passage_ranking(self, model_dir):
model = Model.from_pretrained(model_dir)
pipeline_ins = pipeline(task=Tasks.passage_ranking, model=model)
print(pipeline_ins(input=self.inputs))


if __name__ == '__main__':
unittest.main()
