diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 63b4f1c2..e5c3873b 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -193,6 +193,8 @@ class Pipelines(object):
     plug_generation = 'plug-generation'
     faq_question_answering = 'faq-question-answering'
     conversational_text_to_sql = 'conversational-text-to-sql'
+    sentence_embedding = 'sentence-embedding'
+    passage_ranking = 'passage-ranking'
     relation_extraction = 'relation-extraction'
     document_segmentation = 'document-segmentation'
@@ -245,6 +247,7 @@ class Trainers(object):
     dialog_intent_trainer = 'dialog-intent-trainer'
     nlp_base_trainer = 'nlp-base-trainer'
     nlp_veco_trainer = 'nlp-veco-trainer'
+    nlp_passage_ranking_trainer = 'nlp-passage-ranking-trainer'

     # audio trainers
     speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
@@ -272,6 +275,7 @@ class Preprocessors(object):

     # nlp preprocessor
     sen_sim_tokenizer = 'sen-sim-tokenizer'
+    cross_encoder_tokenizer = 'cross-encoder-tokenizer'
     bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
     text_gen_tokenizer = 'text-gen-tokenizer'
     token_cls_tokenizer = 'token-cls-tokenizer'
@@ -284,6 +288,8 @@ class Preprocessors(object):
     sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'
     zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer'
     text_error_correction = 'text-error-correction'
+    sentence_embedding = 'sentence-embedding'
+    passage_ranking = 'passage-ranking'
     sequence_labeling_tokenizer = 'sequence-labeling-tokenizer'
     word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor'
     fill_mask = 'fill-mask'
diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py
index a3a12c22..d411f1fb 100644
--- a/modelscope/models/nlp/__init__.py
+++ b/modelscope/models/nlp/__init__.py
@@ -29,6 +29,8 @@ if TYPE_CHECKING:
                              SingleBackboneTaskModelBase,
                              TokenClassificationModel)
     from .token_classification import SbertForTokenClassification
+    from .sentence_embedding import SentenceEmbedding
+    from .passage_ranking import PassageRanking

 else:
     _import_structure = {
@@ -62,6 +64,8 @@ else:
             'SingleBackboneTaskModelBase', 'TokenClassificationModel'
         ],
         'token_classification': ['SbertForTokenClassification'],
+        'sentence_embedding': ['SentenceEmbedding'],
+        'passage_ranking': ['PassageRanking'],
     }

     import sys
diff --git a/modelscope/models/nlp/passage_ranking.py b/modelscope/models/nlp/passage_ranking.py
new file mode 100644
index 00000000..68bca231
--- /dev/null
+++ b/modelscope/models/nlp/passage_ranking.py
@@ -0,0 +1,78 @@
+from typing import Any, Dict
+
+import numpy as np
+import torch
+
+from modelscope.metainfo import Models
+from modelscope.models import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.models.nlp import SbertForSequenceClassification
+from modelscope.models.nlp.structbert import SbertPreTrainedModel
+from modelscope.outputs import OutputKeys
+from modelscope.utils.constant import Tasks
+
+__all__ = ['PassageRanking']
+
+
+@MODELS.register_module(Tasks.passage_ranking, module_name=Models.bert)
+class PassageRanking(SbertForSequenceClassification, SbertPreTrainedModel):
+    base_model_prefix: str = 'bert'
+    supports_gradient_checkpointing = True
+    _keys_to_ignore_on_load_missing = [r'position_ids']
+
+    def __init__(self, config, model_dir, *args, **kwargs):
+        if hasattr(config, 'base_model_prefix'):
+            PassageRanking.base_model_prefix = config.base_model_prefix
+        super().__init__(config, model_dir)
+        self.train_batch_size = kwargs.get('train_batch_size', 4)
+        # the positive passage is always placed first in each training group,
+        # so the contrastive target label is 0 for every query
+        self.register_buffer(
+            'target_label',
+            torch.zeros(self.train_batch_size, dtype=torch.long))
+
+    def build_base_model(self):
+        from .structbert import SbertModel
+        return SbertModel(self.config, add_pooling_layer=True)
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
+        outputs = self.base_model.forward(**input)
+
+        # backbone model should return pooled_output as its second output
+        pooled_output = outputs[1]
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+        if self.base_model.training:
+            scores = logits.view(self.train_batch_size, -1)
+            loss_fct = torch.nn.CrossEntropyLoss()
+            loss = loss_fct(scores, self.target_label)
+            return {OutputKeys.LOGITS: logits, OutputKeys.LOSS: loss}
+        return {OutputKeys.LOGITS: logits}
+
+    def sigmoid(self, logits):
+        return np.exp(logits) / (1 + np.exp(logits))
+
+    def postprocess(self, inputs: Dict[str, np.ndarray],
+                    **kwargs) -> Dict[str, np.ndarray]:
+        logits = inputs['logits'].squeeze(-1).detach().cpu().numpy()
+        logits = self.sigmoid(logits).tolist()
+        result = {OutputKeys.SCORES: logits}
+        return result
+
+    @classmethod
+    def _instantiate(cls, **kwargs):
+        """Instantiate the model.
+
+        @param kwargs: Input args.
+                    model_dir: The model dir used to load the checkpoint and the label information.
+                    num_labels: An optional arg to tell the model how many output labels to
+                    initialize. If not supplied, the model uses the default setting (one label,
+                    i.e. a single relevance score per query-passage pair).
+        @return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
+        """
+
+        num_labels = kwargs.get('num_labels', 1)
+        model_args = {} if num_labels is None else {'num_labels': num_labels}
+
+        return super(SbertPreTrainedModel, PassageRanking).from_pretrained(
+            pretrained_model_name_or_path=kwargs.get('model_dir'),
+            model_dir=kwargs.get('model_dir'),
+            **model_args)
diff --git a/modelscope/models/nlp/sentence_embedding.py b/modelscope/models/nlp/sentence_embedding.py
new file mode 100644
index 00000000..955c0e53
--- /dev/null
+++ b/modelscope/models/nlp/sentence_embedding.py
@@ -0,0 +1,74 @@
+import os
+from typing import Any, Dict
+
+import json
+import numpy as np
+
+from modelscope.metainfo import Models
+from modelscope.models import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.models.nlp.structbert import SbertPreTrainedModel
+from modelscope.utils.constant import Tasks
+
+__all__ = ['SentenceEmbedding']
+
+
+@MODELS.register_module(Tasks.sentence_embedding, module_name=Models.bert)
+class SentenceEmbedding(TorchModel, SbertPreTrainedModel):
+    base_model_prefix: str = 'bert'
+    supports_gradient_checkpointing = True
+    _keys_to_ignore_on_load_missing = [r'position_ids']
+
+    def __init__(self, config, model_dir):
+        super().__init__(model_dir)
+        self.config = config
+        setattr(self, self.base_model_prefix, self.build_base_model())
+
+    def build_base_model(self):
+        from .structbert import SbertModel
+        return SbertModel(self.config, add_pooling_layer=False)
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
+        """return the result by the model
+
+        Args:
+            input (Dict[str, Any]): the preprocessed data
+
+        Returns:
+            Dict[str, np.ndarray]: the backbone outputs, including
+                'last_hidden_state' of shape [num_sentences, seq_len, hidden_size];
+                the [CLS] vector of each sentence is used as its embedding
+        """
+        return self.base_model(**input)
+
+    def postprocess(self, inputs: Dict[str, np.ndarray],
+                    **kwargs) -> Dict[str, np.ndarray]:
+        # use the [CLS] vector as the sentence embedding
+        embs = inputs['last_hidden_state'][:, 0].cpu().numpy()
+        num_sent = embs.shape[0]
+        if num_sent >= 2:
+            # dot-product similarity between the source sentence (row 0)
+            # and every sentence to compare (rows 1..n)
+            scores = np.dot(embs[0:1], embs[1:].T).tolist()[0]
+        else:
+            scores = []
+        result = {'text_embedding': embs, 'scores': scores}
+
+        return result
+
+    @classmethod
+    def _instantiate(cls, **kwargs):
+        """Instantiate the model.
+
+        @param kwargs: Input args.
+                    model_dir: The model dir used to load the checkpoint and the label information.
+        @return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
+        """
+        model_args = {}
+
+        return super(SbertPreTrainedModel, SentenceEmbedding).from_pretrained(
+            pretrained_model_name_or_path=kwargs.get('model_dir'),
+            model_dir=kwargs.get('model_dir'),
+            **model_args)
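
Reviewer note: the scoring in `SentenceEmbedding.postprocess` above ranks candidates by the raw dot product between the [CLS] vector of the source sentence and those of the comparison sentences. A minimal standalone NumPy sketch of that computation, with made-up toy vectors:

```python
import numpy as np

# toy [CLS] embeddings: row 0 is the source sentence, rows 1..n the candidates
embs = np.array([[0.1, 0.9, 0.2],   # source
                 [0.0, 1.0, 0.3],   # candidate 1
                 [0.9, 0.1, 0.0]])  # candidate 2

# dot-product similarity of the source against every candidate,
# mirroring `np.dot(embs[0:1], embs[1:].T)` in postprocess
scores = embs[0] @ embs[1:].T
print(scores)  # candidate 1 scores higher than candidate 2
```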
diff --git a/modelscope/msdatasets/task_datasets/__init__.py b/modelscope/msdatasets/task_datasets/__init__.py
index f97ff8b2..e2bf5bc1 100644
--- a/modelscope/msdatasets/task_datasets/__init__.py
+++ b/modelscope/msdatasets/task_datasets/__init__.py
@@ -11,12 +11,14 @@ if TYPE_CHECKING:
     from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset
     from .movie_scene_segmentation import MovieSceneSegmentationDataset
     from .video_summarization_dataset import VideoSummarizationDataset
+    from .passage_ranking_dataset import PassageRankingDataset

 else:
     _import_structure = {
         'base': ['TaskDataset'],
         'builder': ['TASK_DATASETS', 'build_task_dataset'],
         'torch_base_dataset': ['TorchTaskDataset'],
+        'passage_ranking_dataset': ['PassageRankingDataset'],
         'veco_dataset': ['VecoDataset'],
         'image_instance_segmentation_coco_dataset':
         ['ImageInstanceSegmentationCocoDataset'],
diff --git a/modelscope/msdatasets/task_datasets/passage_ranking_dataset.py b/modelscope/msdatasets/task_datasets/passage_ranking_dataset.py
new file mode 100644
index 00000000..517e0d36
--- /dev/null
+++ b/modelscope/msdatasets/task_datasets/passage_ranking_dataset.py
@@ -0,0 +1,151 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import random
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple, Union
+
+import torch
+from datasets import Dataset, IterableDataset, concatenate_datasets
+from torch.utils.data import ConcatDataset
+from transformers import DataCollatorWithPadding
+
+from modelscope.metainfo import Models
+from modelscope.utils.constant import ModeKeys, Tasks
+from .base import TaskDataset
+from .builder import TASK_DATASETS
+from .torch_base_dataset import TorchTaskDataset
+
+
+@TASK_DATASETS.register_module(
+    group_key=Tasks.passage_ranking, module_name=Models.bert)
+class PassageRankingDataset(TorchTaskDataset):
+
+    def __init__(self,
+                 datasets: Union[Any, List[Any]],
+                 mode,
+                 preprocessor=None,
+                 *args,
+                 **kwargs):
+        self.seed = kwargs.get('seed', 42)
+        self.permutation = None
+        self.datasets = None
+        self.dataset_config = kwargs
+        self.query_sequence = self.dataset_config.get('query_sequence',
+                                                      'query')
+        self.pos_sequence = self.dataset_config.get('pos_sequence',
+                                                    'positive_passages')
+        self.neg_sequence = self.dataset_config.get('neg_sequence',
+                                                    'negative_passages')
+        self.passage_text_fields = self.dataset_config.get(
+            'passage_text_fields', ['title', 'text'])
+        self.qid_field = self.dataset_config.get('qid_field', 'query_id')
+        if mode == ModeKeys.TRAIN:
+            train_config = kwargs.get('train', {})
+            self.neg_samples = train_config.get('neg_samples', 4)
+
+        super().__init__(datasets, mode, preprocessor, **kwargs)
+
+    def __getitem__(self, index) -> Any:
+        if self.mode == ModeKeys.TRAIN:
+            return self.__get_train_item__(index)
+        else:
+            return self.__get_test_item__(index)
+
+    def __get_test_item__(self, index):
+        group = self._inner_dataset[index]
+        labels = []
+
+        qry = group[self.query_sequence]
+
+        pos_sequences = group[self.pos_sequence]
+        pos_sequences = [
+            ' '.join([ele[key] for key in self.passage_text_fields])
+            for ele in pos_sequences
+        ]
+        labels.extend([1] * len(pos_sequences))
+
+        neg_sequences = group[self.neg_sequence]
+        neg_sequences = [
+            ' '.join([ele[key] for key in self.passage_text_fields])
+            for ele in neg_sequences
+        ]
+        labels.extend([0] * len(neg_sequences))
+
+        qid = group[self.qid_field]
+
+        examples = pos_sequences + neg_sequences
+        sample = {
+            'qid': torch.LongTensor([int(qid)] * len(labels)),
+            self.preprocessor.first_sequence: qry,
+            self.preprocessor.second_sequence: examples,
+            'labels': torch.LongTensor(labels)
+        }
+        return self.prepare_sample(sample)
+
+    def __get_train_item__(self, index):
+        group = self._inner_dataset[index]
+
+        qry = group[self.query_sequence]
+
+        pos_sequences = group[self.pos_sequence]
+        pos_sequences = [
+            ' '.join([ele[key] for key in self.passage_text_fields])
+            for ele in pos_sequences
+        ]
+
+        neg_sequences = group[self.neg_sequence]
+        neg_sequences = [
+            ' '.join([ele[key] for key in self.passage_text_fields])
+            for ele in neg_sequences
+        ]
+
+        pos_psg = random.choice(pos_sequences)
+
+        # sample with replacement when there are fewer negatives than requested
+        if len(neg_sequences) < self.neg_samples:
+            negs = random.choices(neg_sequences, k=self.neg_samples)
+        else:
+            negs = random.sample(neg_sequences, k=self.neg_samples)
+        # the positive passage always comes first, matching the model's target label of 0
+        examples = [pos_psg] + negs
+        sample = {
+            self.preprocessor.first_sequence: qry,
+            self.preprocessor.second_sequence: examples,
+        }
+        return self.prepare_sample(sample)
+
+    def __len__(self):
+        return len(self._inner_dataset)
+
+    def prepare_dataset(self, datasets: Union[Any, List[Any]]) -> Any:
+        """Prepare a dataset.
+
+        Users can process the input datasets from a whole-dataset perspective.
+        This method gives a default implementation of dataset merging; users can
+        override it to implement custom logic.
+
+        Args:
+            datasets: The original dataset(s)
+
+        Returns: A single dataset, which may be created after merging.
+
+        """
+        if isinstance(datasets, List):
+            if len(datasets) == 1:
+                return datasets[0]
+            elif len(datasets) > 1:
+                return ConcatDataset(datasets)
+        else:
+            return datasets
+
+    def prepare_sample(self, data):
+        """Preprocess the data fetched from the inner_dataset.
+
+        If the preprocessor is None, the original data will be returned; otherwise
+        the preprocessor will be called. Users can override this method to
+        implement custom logic.
+
+        Args:
+            data: The data fetched from the dataset.
+
+        Returns: The processed data.
+
+        """
+        return self.preprocessor(
+            data) if self.preprocessor is not None else data
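
Reviewer note: to make the data flow concrete, here is a sketch of what `__get_train_item__` consumes and emits, using the default field names configured above. The record below is illustrative, not real MS MARCO data:

```python
import random

record = {
    'query_id': '3',
    'query': "how long it take to get a master's degree",
    'positive_passages': [{'title': 'Degrees', 'text': 'About 18 to 24 months.'}],
    'negative_passages': [{'title': 'Tuition', 'text': 'Costs vary widely.'},
                          {'title': 'Visas', 'text': 'Processing takes weeks.'}],
}

def passages(group):
    # join the configured passage text fields, as the dataset class does
    return [' '.join([p['title'], p['text']]) for p in group]

neg_samples = 4
pos = random.choice(passages(record['positive_passages']))
negs = passages(record['negative_passages'])
# sample with replacement when there are not enough negatives
negs = (random.choices(negs, k=neg_samples)
        if len(negs) < neg_samples else random.sample(negs, k=neg_samples))

# the positive passage is first, matching PassageRanking.target_label == 0
sample = {'source_sentence': record['query'],
          'sentences_to_compare': [pos] + negs}
```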
diff --git a/modelscope/outputs.py b/modelscope/outputs.py
index 37ab3481..8ddeb314 100644
--- a/modelscope/outputs.py
+++ b/modelscope/outputs.py
@@ -387,19 +387,14 @@ TASK_OUTPUTS = {
     #     "output": "我想吃苹果"
     # }
     Tasks.text_error_correction: [OutputKeys.OUTPUT],
-
+    Tasks.sentence_embedding: [OutputKeys.TEXT_EMBEDDING, OutputKeys.SCORES],
+    Tasks.passage_ranking: [OutputKeys.SCORES],

     # text generation result for single sample
     # {
     #     "text": "this is the text generated by a model."
     # }
     Tasks.text_generation: [OutputKeys.TEXT],

-    # text feature extraction for single sample
-    # {
-    #     "text_embedding": np.array with shape [1, D]
-    # }
-    Tasks.sentence_embedding: [OutputKeys.TEXT_EMBEDDING],
-
     # fill mask result for single sample
     # {
     #     "text": "this is the text which masks filled by model."
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index a1f093a3..50313cf7 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -17,6 +17,11 @@ PIPELINES = Registry('pipelines')

 DEFAULT_MODEL_FOR_PIPELINE = {
     # TaskName: (pipeline_module_name, model_repo)
+    Tasks.sentence_embedding:
+    (Pipelines.sentence_embedding,
+     'damo/nlp_corom_sentence-embedding_english-base'),
+    Tasks.passage_ranking: (Pipelines.passage_ranking,
+                            'damo/nlp_corom_passage-ranking_english-base'),
     Tasks.word_segmentation:
     (Pipelines.word_segmentation,
      'damo/nlp_structbert_word-segmentation_chinese-base'),
diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py
index 42dfc972..6f898c0f 100644
--- a/modelscope/pipelines/nlp/__init__.py
+++ b/modelscope/pipelines/nlp/__init__.py
@@ -25,7 +25,8 @@ if TYPE_CHECKING:
     from .translation_pipeline import TranslationPipeline
     from .word_segmentation_pipeline import WordSegmentationPipeline
     from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline
-
+    from .passage_ranking_pipeline import PassageRankingPipeline
+    from .sentence_embedding_pipeline import SentenceEmbeddingPipeline
 else:
     _import_structure = {
         'conversational_text_to_sql_pipeline':
@@ -55,6 +56,8 @@ else:
         'word_segmentation_pipeline': ['WordSegmentationPipeline'],
         'zero_shot_classification_pipeline':
         ['ZeroShotClassificationPipeline'],
+        'passage_ranking_pipeline': ['PassageRankingPipeline'],
+        'sentence_embedding_pipeline': ['SentenceEmbeddingPipeline']
     }

     import sys
diff --git a/modelscope/pipelines/nlp/passage_ranking_pipeline.py b/modelscope/pipelines/nlp/passage_ranking_pipeline.py
new file mode 100644
index 00000000..c03e7b93
--- /dev/null
+++ b/modelscope/pipelines/nlp/passage_ranking_pipeline.py
@@ -0,0 +1,58 @@
+from typing import Any, Dict, Optional, Union
+
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models import Model
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import PassageRankingPreprocessor, Preprocessor
+from modelscope.utils.constant import Tasks
+
+__all__ = ['PassageRankingPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.passage_ranking, module_name=Pipelines.passage_ranking)
+class PassageRankingPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[Model, str],
+                 preprocessor: Optional[Preprocessor] = None,
+                 **kwargs):
+        """Use `model` and `preprocessor` to create a passage ranking pipeline for prediction.
+
+        Args:
+            model (str or Model): Supply either a local model dir that supports the
+                passage ranking task, a model id from the model hub, or a torch model instance.
+            preprocessor (Preprocessor): An optional preprocessor instance; please make sure
+                the preprocessor fits the model if supplied.
+            sequence_length: Max sequence length in the user's custom scenario. 128 will be
+                used as a default value.
+        """
+        model = model if isinstance(model,
+                                    Model) else Model.from_pretrained(model)
+
+        if preprocessor is None:
+            preprocessor = PassageRankingPreprocessor(
+                model.model_dir if isinstance(model, Model) else model,
+                sequence_length=kwargs.pop('sequence_length', 128))
+        model.eval()
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+
+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
+        with torch.no_grad():
+            return {**self.model(inputs, **forward_params)}
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Process the prediction results.
+
+        Args:
+            inputs (Dict[str, Any]): the model outputs produced by forward
+
+        Returns:
+            Dict[str, Any]: one relevance score per query-passage pair
+        """
+        pred_list = inputs[OutputKeys.SCORES]
+
+        return {OutputKeys.SCORES: pred_list}
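
Reviewer note: once merged, the new pipeline is invocable like any other ModelScope pipeline. A usage sketch (the model id is the default registered in builder.py above; the scores are the per-passage sigmoid outputs from `PassageRanking.postprocess`):

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

ranker = pipeline(Tasks.passage_ranking,
                  model='damo/nlp_corom_passage-ranking_english-base')
result = ranker(input={
    'source_sentence': ["how long it take to get a master's degree"],
    'sentences_to_compare': ['About 18 to 24 months.', 'Costs vary widely.'],
})
# result == {'scores': [s1, s2]}, one relevance score per candidate passage
```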
+ """ + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = SentenceEmbeddingPreprocessor( + model.model_dir if isinstance(model, Model) else model, + first_sequence=first_sequence, + sequence_length=kwargs.pop('sequence_length', 128)) + model.eval() + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + return {**self.model(inputs, **forward_params)} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, Any]: the predicted text representation + """ + embs = inputs[OutputKeys.TEXT_EMBEDDING] + scores = inputs[OutputKeys.SCORES] + return {OutputKeys.TEXT_EMBEDDING: embs, OutputKeys.SCORES: scores} diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index 6012b5ba..212339ae 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -23,7 +23,8 @@ if TYPE_CHECKING: ZeroShotClassificationPreprocessor, NERPreprocessor, TextErrorCorrectionPreprocessor, FaqQuestionAnsweringPreprocessor, SequenceLabelingPreprocessor, RelationExtractionPreprocessor, - DocumentSegmentationPreprocessor, FillMaskPoNetPreprocessor) + DocumentSegmentationPreprocessor, FillMaskPoNetPreprocessor, + PassageRankingPreprocessor) from .space import (DialogIntentPredictionPreprocessor, DialogModelingPreprocessor, DialogStateTrackingPreprocessor) @@ -50,6 +51,7 @@ else: 'SingleSentenceClassificationPreprocessor', 'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor', 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', + 'SentenceEmbeddingPreprocessor', 'PassageRankingPreprocessor', 'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor', 'RelationExtractionPreprocessor', diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py index 9137b105..e20adaa6 100644 --- a/modelscope/preprocessors/nlp.py +++ b/modelscope/preprocessors/nlp.py @@ -29,6 +29,7 @@ __all__ = [ 'PairSentenceClassificationPreprocessor', 'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor', 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', + 'SentenceEmbeddingPreprocessor', 'PassageRankingPreprocessor', 'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor', 'RelationExtractionPreprocessor', 'DocumentSegmentationPreprocessor', 'FillMaskPoNetPreprocessor' @@ -100,6 +101,7 @@ class SequenceClassificationPreprocessor(Preprocessor): text_a = new_data[self.first_sequence] text_b = new_data.get(self.second_sequence, None) + feature = self.tokenizer( text_a, text_b, @@ -111,7 +113,6 @@ class SequenceClassificationPreprocessor(Preprocessor): rst['input_ids'].append(feature['input_ids']) rst['attention_mask'].append(feature['attention_mask']) rst['token_type_ids'].append(feature['token_type_ids']) - return rst @@ -268,6 +269,62 @@ class NLPTokenizerPreprocessorBase(Preprocessor): output[OutputKeys.LABELS] = labels +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.passage_ranking) +class PassageRankingPreprocessor(NLPTokenizerPreprocessorBase): + """The tokenizer preprocessor used in passage ranking model. 
+ """ + + def __init__(self, + model_dir: str, + mode=ModeKeys.INFERENCE, + *args, + **kwargs): + """preprocess the data + + Args: + model_dir (str): model path + """ + super().__init__(model_dir, pair=True, mode=mode, *args, **kwargs) + self.model_dir: str = model_dir + self.first_sequence: str = kwargs.pop('first_sequence', + 'source_sentence') + self.second_sequence = kwargs.pop('second_sequence', + 'sentences_to_compare') + self.sequence_length = kwargs.pop('sequence_length', 128) + + self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir) + + @type_assert(object, (str, tuple, Dict)) + def __call__(self, data: Union[tuple, Dict]) -> Dict[str, Any]: + if isinstance(data, tuple): + sentence1, sentence2 = data + elif isinstance(data, dict): + sentence1 = data.get(self.first_sequence) + sentence2 = data.get(self.second_sequence) + if isinstance(sentence2, str): + sentence2 = [sentence2] + if isinstance(sentence1, str): + sentence1 = [sentence1] + sentence1 = sentence1 * len(sentence2) + + max_seq_length = self.sequence_length + feature = self.tokenizer( + sentence1, + sentence2, + padding='max_length', + truncation=True, + max_length=max_seq_length, + return_tensors='pt') + if 'labels' in data: + labels = data['labels'] + feature['labels'] = labels + if 'qid' in data: + qid = data['qid'] + feature['qid'] = qid + return feature + + @PREPROCESSORS.register_module( Fields.nlp, module_name=Preprocessors.nli_tokenizer) @PREPROCESSORS.register_module( @@ -298,6 +355,51 @@ class SingleSentenceClassificationPreprocessor(NLPTokenizerPreprocessorBase): super().__init__(model_dir, pair=False, mode=mode, **kwargs) +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.sentence_embedding) +class SentenceEmbeddingPreprocessor(NLPTokenizerPreprocessorBase): + """The tokenizer preprocessor used in sentence embedding. 
+ """ + + def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): + kwargs['truncation'] = kwargs.get('truncation', True) + kwargs['padding'] = kwargs.get( + 'padding', False if mode == ModeKeys.INFERENCE else 'max_length') + kwargs['max_length'] = kwargs.pop('sequence_length', 128) + super().__init__(model_dir, pair=False, mode=mode, **kwargs) + + def __call__(self, data: Union[str, Dict]) -> Dict[str, Any]: + """process the raw input data + + Args: + data Dict: + keys: "source_sentence" && "sentences_to_compare" + values: list of sentences + Example: + {"source_sentence": ["how long it take to get a master's degree"], + "sentences_to_compare": ["On average, students take about 18 to 24 months + to complete a master's degree.", + "On the other hand, some students prefer to go at a slower pace + and choose to take several years to complete their studies.", + "It can take anywhere from two semesters"]} + Returns: + Dict[str, Any]: the preprocessed data + """ + source_sentence = data['source_sentence'] + compare_sentences = data['sentences_to_compare'] + sentences = [] + sentences.append(source_sentence[0]) + for sent in compare_sentences: + sentences.append(sent) + + tokenized_inputs = self.tokenizer( + sentences, + return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None, + padding=True, + truncation=True) + return tokenized_inputs + + @PREPROCESSORS.register_module( Fields.nlp, module_name=Preprocessors.zero_shot_cls_tokenizer) class ZeroShotClassificationPreprocessor(NLPTokenizerPreprocessorBase): diff --git a/modelscope/trainers/__init__.py b/modelscope/trainers/__init__.py index 8f8938c8..a632642a 100644 --- a/modelscope/trainers/__init__.py +++ b/modelscope/trainers/__init__.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: ImagePortraitEnhancementTrainer, MovieSceneSegmentationTrainer) from .multi_modal import CLIPTrainer - from .nlp import SequenceClassificationTrainer + from .nlp import SequenceClassificationTrainer, PassageRankingTrainer from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer from .trainer import EpochBasedTrainer @@ -25,7 +25,7 @@ else: 'ImagePortraitEnhancementTrainer', 'MovieSceneSegmentationTrainer' ], 'multi_modal': ['CLIPTrainer'], - 'nlp': ['SequenceClassificationTrainer'], + 'nlp': ['SequenceClassificationTrainer', 'PassageRankingTrainer'], 'nlp_trainer': ['NlpEpochBasedTrainer', 'VecoTrainer'], 'trainer': ['EpochBasedTrainer'] } diff --git a/modelscope/trainers/nlp/__init__.py b/modelscope/trainers/nlp/__init__.py index 7ab8fd70..001cfefc 100644 --- a/modelscope/trainers/nlp/__init__.py +++ b/modelscope/trainers/nlp/__init__.py @@ -6,10 +6,12 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .sequence_classification_trainer import SequenceClassificationTrainer from .csanmt_translation_trainer import CsanmtTranslationTrainer + from .passage_ranking_trainer import PassageRankingTranier else: _import_structure = { 'sequence_classification_trainer': ['SequenceClassificationTrainer'], 'csanmt_translation_trainer': ['CsanmtTranslationTrainer'], + 'passage_ranking_trainer': ['PassageRankingTrainer'] } import sys diff --git a/modelscope/trainers/nlp/passage_ranking_trainer.py b/modelscope/trainers/nlp/passage_ranking_trainer.py new file mode 100644 index 00000000..e54c2904 --- /dev/null +++ b/modelscope/trainers/nlp/passage_ranking_trainer.py @@ -0,0 +1,197 @@ +import time +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch 
diff --git a/modelscope/trainers/nlp/passage_ranking_trainer.py b/modelscope/trainers/nlp/passage_ranking_trainer.py
new file mode 100644
index 00000000..e54c2904
--- /dev/null
+++ b/modelscope/trainers/nlp/passage_ranking_trainer.py
@@ -0,0 +1,197 @@
+import time
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.utils.data import DataLoader, Dataset
+
+from modelscope.metainfo import Trainers
+from modelscope.models.base import Model, TorchModel
+from modelscope.msdatasets.ms_dataset import MsDataset
+from modelscope.preprocessors.base import Preprocessor
+from modelscope.trainers.base import BaseTrainer
+from modelscope.trainers.builder import TRAINERS
+from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer
+from modelscope.utils.constant import DEFAULT_MODEL_REVISION
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@dataclass
+class GroupCollator():
+    """
+    Flattens grouped features (one list of encoded query-passage pairs per query)
+    into a single flat list, then concatenates each tensor field into one batch
+    tensor, hiding the grouped data layout from the model.
+    """
+
+    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
+        if isinstance(features[0], list):
+            features = sum(features, [])
+        keys = features[0].keys()
+        batch = {k: list() for k in keys}
+        for ele in features:
+            for k, v in ele.items():
+                batch[k].append(v)
+        batch = {k: torch.cat(v, dim=0) for k, v in batch.items()}
+        return batch
+
+
+@TRAINERS.register_module(module_name=Trainers.nlp_passage_ranking_trainer)
+class PassageRankingTrainer(NlpEpochBasedTrainer):
+
+    def __init__(
+            self,
+            model: Optional[Union[TorchModel, nn.Module, str]] = None,
+            cfg_file: Optional[str] = None,
+            cfg_modify_fn: Optional[Callable] = None,
+            arg_parse_fn: Optional[Callable] = None,
+            data_collator: Optional[Callable] = None,
+            train_dataset: Optional[Union[MsDataset, Dataset]] = None,
+            eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
+            preprocessor: Optional[Preprocessor] = None,
+            optimizers: Tuple[torch.optim.Optimizer,
+                              torch.optim.lr_scheduler._LRScheduler] = (None,
+                                                                        None),
+            model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
+            **kwargs):
+
+        if data_collator is None:
+            data_collator = GroupCollator()
+
+        super().__init__(
+            model=model,
+            cfg_file=cfg_file,
+            cfg_modify_fn=cfg_modify_fn,
+            arg_parse_fn=arg_parse_fn,
+            data_collator=data_collator,
+            preprocessor=preprocessor,
+            optimizers=optimizers,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            model_revision=model_revision,
+            **kwargs)
+
+    def compute_mrr(self, result, k=10):
+        mrr = 0
+        for res in result.values():
+            sorted_res = sorted(res, key=lambda x: x[0], reverse=True)
+            ar = 0
+            for index, ele in enumerate(sorted_res[:k]):
+                if str(ele[1]) == '1':
+                    ar = 1.0 / (index + 1)
+                    break
+            mrr += ar
+        return mrr / len(result)
+
+    def compute_ndcg(self, result, k=10):
+        ndcg = 0
+        from sklearn.metrics import ndcg_score
+        for res in result.values():
+            sorted_res = sorted(res, key=lambda x: x[0], reverse=True)
+            labels = np.array([[ele[1] for ele in sorted_res]])
+            scores = np.array([[ele[0] for ele in sorted_res]])
+            ndcg += float(ndcg_score(labels, scores, k=k))
+        ndcg = ndcg / len(result)
+        return ndcg
+
+    def evaluate(self,
+                 checkpoint_path: Optional[str] = None,
+                 *args,
+                 **kwargs) -> Dict[str, float]:
+        """Evaluate a dataset.
+
+        Evaluate a dataset via a specific model loaded from `checkpoint_path`; if
+        `checkpoint_path` does not exist, read the model from the config file.
+
+        Args:
+            checkpoint_path (Optional[str], optional): the model path. Defaults to None.
+
+        Returns:
+            Dict[str, float]: the evaluation results, keyed by metric name.
+            Example:
+                {'mrr@10': 0.3865, 'ndcg@10': 0.4492}
+        """
+        from modelscope.models.nlp import PassageRanking
+        # get the raw online dataset
+        self.eval_dataloader = self._build_dataloader_with_dataset(
+            self.eval_dataset,
+            **self.cfg.evaluation.get('dataloader', {}),
+            collate_fn=self.eval_data_collator)
+        # generate a standard dataloader
+        # generate a model
+        if checkpoint_path is not None:
+            model = PassageRanking.from_pretrained(checkpoint_path)
+        else:
+            model = self.model
+
+        # adapted from easynlp
+        model.eval()
+        total_samples = 0
+
+        logits_list = list()
+        label_list = list()
+        qid_list = list()
+
+        total_spent_time = 0.0
+        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        model.to(device)
+        for _step, batch in enumerate(self.eval_dataloader):
+            try:
+                batch = {
+                    key:
+                    val.to(device) if isinstance(val, torch.Tensor) else val
+                    for key, val in batch.items()
+                }
+            except RuntimeError:
+                batch = {key: val for key, val in batch.items()}
+
+            infer_start_time = time.time()
+            with torch.no_grad():
+                label_ids = batch.pop('labels').detach().cpu().numpy()
+                qids = batch.pop('qid').detach().cpu().numpy()
+                outputs = model(batch)
+            infer_end_time = time.time()
+            total_spent_time += infer_end_time - infer_start_time
+            total_samples += self.eval_dataloader.batch_size
+
+            assert 'scores' in outputs
+            logits = outputs['scores']
+
+            label_list.extend(label_ids)
+            logits_list.extend(logits)
+            qid_list.extend(qids)
+
+        logger.info('Inference time = {:.2f}s, [{:.4f} ms / sample] '.format(
+            total_spent_time, total_spent_time * 1000 / total_samples))
+
+        rank_result = {}
+        for qid, score, label in zip(qid_list, logits_list, label_list):
+            if qid not in rank_result:
+                rank_result[qid] = []
+            rank_result[qid].append((score, label))
+
+        for qid in rank_result:
+            rank_result[qid] = sorted(rank_result[qid], key=lambda x: x[0])
+
+        eval_outputs = list()
+        for metric in self.metrics:
+            if metric.startswith('mrr'):
+                k = int(metric.split('@')[-1])
+                mrr = self.compute_mrr(rank_result, k=k)
+                logger.info('{}: {}'.format(metric, mrr))
+                eval_outputs.append((metric, mrr))
+            elif metric.startswith('ndcg'):
+                k = int(metric.split('@')[-1])
+                ndcg = self.compute_ndcg(rank_result, k=k)
+                logger.info('{}: {}'.format(metric, ndcg))
+                eval_outputs.append((metric, ndcg))
+            else:
+                raise NotImplementedError('Metric %s not implemented' % metric)
+
+        return dict(eval_outputs)
diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py
index 63a231b3..8dc75a65 100644
--- a/modelscope/trainers/trainer.py
+++ b/modelscope/trainers/trainer.py
@@ -345,12 +345,12 @@ class EpochBasedTrainer(BaseTrainer):
                         type=self.cfg.task, mode=mode, datasets=datasets)
                 return build_task_dataset(cfg, self.cfg.task)
             else:
-                task_data_config.update(
-                    dict(
-                        mode=mode,
-                        datasets=datasets,
-                        preprocessor=preprocessor))
-                return build_task_dataset(task_data_config, self.cfg.task)
+                # avoid adding non-str values (datasets, preprocessor) into the cfg
+                task_data_build_config = ConfigDict(
+                    mode=mode, datasets=datasets, preprocessor=preprocessor)
+                task_data_build_config.update(task_data_config)
+                return build_task_dataset(task_data_build_config,
+                                          self.cfg.task)
         except Exception:
             if isinstance(datasets, (List, Tuple)) or preprocessor is not None:
                 return TorchTaskDataset(
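
Reviewer note: a quick worked example of the `mrr@k` metric computed in `evaluate` above (the reciprocal rank of the first relevant passage, averaged over queries), mirroring `compute_mrr`:

```python
# per-query lists of (score, label) pairs, as built in evaluate()
result = {
    'q1': [(0.9, 1), (0.5, 0), (0.1, 0)],  # relevant passage ranked 1st -> RR = 1.0
    'q2': [(0.8, 0), (0.6, 1), (0.2, 0)],  # relevant passage ranked 2nd -> RR = 0.5
}

def mrr(result, k=10):
    total = 0.0
    for res in result.values():
        ranked = sorted(res, key=lambda x: x[0], reverse=True)
        total += next((1.0 / (i + 1)
                       for i, (_, label) in enumerate(ranked[:k])
                       if str(label) == '1'), 0.0)
    return total / len(result)

print(mrr(result))  # (1.0 + 0.5) / 2 = 0.75
```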
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 6d84925c..57d38da7 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -89,6 +89,8 @@ class NLPTasks(object):
     sentiment_analysis = 'sentiment-analysis'
     sentence_similarity = 'sentence-similarity'
     text_classification = 'text-classification'
+    sentence_embedding = 'sentence-embedding'
+    passage_ranking = 'passage-ranking'
     relation_extraction = 'relation-extraction'
     zero_shot = 'zero-shot'
     translation = 'translation'
diff --git a/tests/pipelines/test_passage_ranking.py b/tests/pipelines/test_passage_ranking.py
new file mode 100644
index 00000000..5faa365e
--- /dev/null
+++ b/tests/pipelines/test_passage_ranking.py
@@ -0,0 +1,61 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import shutil
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models import Model
+from modelscope.models.nlp import PassageRanking
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.nlp import PassageRankingPipeline
+from modelscope.preprocessors import PassageRankingPreprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class PassageRankingTest(unittest.TestCase):
+    model_id = 'damo/nlp_corom_passage-ranking_english-base'
+    inputs = {
+        'source_sentence': ["how long it take to get a master's degree"],
+        'sentences_to_compare': [
+            "On average, students take about 18 to 24 months to complete a master's degree.",
+            'On the other hand, some students prefer to go at a slower pace and choose to take '
+            'several years to complete their studies.',
+            'It can take anywhere from two semesters'
+        ]
+    }
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_by_direct_model_download(self):
+        cache_path = snapshot_download(self.model_id)
+        tokenizer = PassageRankingPreprocessor(cache_path)
+        model = PassageRanking.from_pretrained(cache_path)
+        pipeline1 = PassageRankingPipeline(model, preprocessor=tokenizer)
+        pipeline2 = pipeline(
+            Tasks.passage_ranking, model=model, preprocessor=tokenizer)
+        print(f'sentence: {self.inputs}\n'
+              f'pipeline1:{pipeline1(input=self.inputs)}')
+        print()
+        print(f'pipeline2: {pipeline2(input=self.inputs)}')
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.model_id)
+        tokenizer = PassageRankingPreprocessor(model.model_dir)
+        pipeline_ins = pipeline(
+            task=Tasks.passage_ranking, model=model, preprocessor=tokenizer)
+        print(pipeline_ins(input=self.inputs))
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.passage_ranking, model=self.model_id)
+        print(pipeline_ins(input=self.inputs))
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_default_model(self):
+        pipeline_ins = pipeline(task=Tasks.passage_ranking)
+        print(pipeline_ins(input=self.inputs))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/pipelines/test_sentence_embedding.py b/tests/pipelines/test_sentence_embedding.py
new file mode 100644
index 00000000..739dd7ab
--- /dev/null
+++ b/tests/pipelines/test_sentence_embedding.py
@@ -0,0 +1,82 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import shutil
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models import Model
+from modelscope.models.nlp import SentenceEmbedding
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.nlp import SentenceEmbeddingPipeline
+from modelscope.preprocessors import SentenceEmbeddingPreprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class SentenceEmbeddingTest(unittest.TestCase):
+    model_id = 'damo/nlp_corom_sentence-embedding_english-base'
+    inputs = {
+        'source_sentence': ["how long it take to get a master's degree"],
+        'sentences_to_compare': [
+            "On average, students take about 18 to 24 months to complete a master's degree.",
+            'On the other hand, some students prefer to go at a slower pace and choose to take '
+            'several years to complete their studies.',
+            'It can take anywhere from two semesters'
+        ]
+    }
+
+    inputs2 = {
+        'source_sentence': ["how long it take to get a master's degree"],
+        'sentences_to_compare': [
+            "On average, students take about 18 to 24 months to complete a master's degree."
+        ]
+    }
+
+    inputs3 = {
+        'source_sentence': ["how long it take to get a master's degree"],
+        'sentences_to_compare': []
+    }
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_by_direct_model_download(self):
+        cache_path = snapshot_download(self.model_id)
+        tokenizer = SentenceEmbeddingPreprocessor(cache_path)
+        model = SentenceEmbedding.from_pretrained(cache_path)
+        pipeline1 = SentenceEmbeddingPipeline(model, preprocessor=tokenizer)
+        pipeline2 = pipeline(
+            Tasks.sentence_embedding, model=model, preprocessor=tokenizer)
+        print(f'inputs: {self.inputs}\n'
+              f'pipeline1:{pipeline1(input=self.inputs)}')
+        print()
+        print(f'pipeline2: {pipeline2(input=self.inputs)}')
+        print()
+        print(f'inputs: {self.inputs2}\n'
+              f'pipeline1:{pipeline1(input=self.inputs2)}')
+        print()
+        print(f'pipeline2: {pipeline2(input=self.inputs2)}')
+        print(f'inputs: {self.inputs3}\n'
+              f'pipeline1:{pipeline1(input=self.inputs3)}')
+        print()
+        print(f'pipeline2: {pipeline2(input=self.inputs3)}')
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.model_id)
+        tokenizer = SentenceEmbeddingPreprocessor(model.model_dir)
+        pipeline_ins = pipeline(
+            task=Tasks.sentence_embedding, model=model, preprocessor=tokenizer)
+        print(pipeline_ins(input=self.inputs))
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.sentence_embedding, model=self.model_id)
+        print(pipeline_ins(input=self.inputs))
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_default_model(self):
+        pipeline_ins = pipeline(task=Tasks.sentence_embedding)
+        print(pipeline_ins(input=self.inputs))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/trainers/test_finetune_passage_ranking.py b/tests/trainers/test_finetune_passage_ranking.py
new file mode 100644
index 00000000..f833f981
--- /dev/null
+++ b/tests/trainers/test_finetune_passage_ranking.py
@@ -0,0 +1,133 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import tempfile
+import unittest
+from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
+
+import torch
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
+from modelscope.metainfo import Trainers
+from modelscope.models import Model
+from modelscope.msdatasets import MsDataset
+from modelscope.pipelines import pipeline
+from modelscope.trainers import build_trainer
+from modelscope.utils.constant import ModelFile, Tasks
+
+
+class TestFinetunePassageRanking(unittest.TestCase):
+    inputs = {
+        'source_sentence': ["how long it take to get a master's degree"],
+        'sentences_to_compare': [
+            "On average, students take about 18 to 24 months to complete a master's degree.",
+            'On the other hand, some students prefer to go at a slower pace and choose to take '
+            'several years to complete their studies.',
+            'It can take anywhere from two semesters'
+        ]
+    }
+
+    def setUp(self):
+        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+        self.tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(self.tmp_dir):
+            os.makedirs(self.tmp_dir)
+
+    def tearDown(self):
+        shutil.rmtree(self.tmp_dir)
+        super().tearDown()
+
+    def finetune(self,
+                 model_id,
+                 train_dataset,
+                 eval_dataset,
+                 name=Trainers.nlp_passage_ranking_trainer,
+                 cfg_modify_fn=None,
+                 **kwargs):
+        kwargs = dict(
+            model=model_id,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            work_dir=self.tmp_dir,
+            cfg_modify_fn=cfg_modify_fn,
+            **kwargs)
+
+        os.environ['LOCAL_RANK'] = '0'
+        trainer = build_trainer(name=name, default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+
+    def test_finetune_msmarco(self):
+
+        def cfg_modify_fn(cfg):
+            cfg.task = 'passage-ranking'
+            cfg['preprocessor'] = {'type': 'passage-ranking'}
+            cfg.train.optimizer.lr = 2e-5
+            cfg['dataset'] = {
+                'train': {
+                    'type': 'bert',
+                    'query_sequence': 'query',
+                    'pos_sequence': 'positive_passages',
+                    'neg_sequence': 'negative_passages',
+                    'passage_text_fields': ['title', 'text'],
+                    'qid_field': 'query_id'
+                },
+                'val': {
+                    'type': 'bert',
+                    'query_sequence': 'query',
+                    'pos_sequence': 'positive_passages',
+                    'neg_sequence': 'negative_passages',
+                    'passage_text_fields': ['title', 'text'],
+                    'qid_field': 'query_id'
+                },
+            }
+            cfg['train']['neg_samples'] = 4
+            cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30
+            cfg.train.max_epochs = 1
+            cfg.train.train_batch_size = 4
+            cfg.train.lr_scheduler = {
+                'type': 'LinearLR',
+                'start_factor': 1.0,
+                'end_factor': 0.0,
+                'options': {
+                    'by_epoch': False
+                }
+            }
+            cfg.train.hooks = [{
+                'type': 'CheckpointHook',
+                'interval': 1
+            }, {
+                'type': 'TextLoggerHook',
+                'interval': 1
+            }, {
+                'type': 'IterTimerHook'
+            }, {
+                'type': 'EvaluationHook',
+                'by_epoch': False,
+                'interval': 3000
+            }]
+            return cfg
+
+        # load dataset
+        ds = MsDataset.load('passage-ranking-demo', 'zyznull')
+        train_ds = ds['train'].to_hf_dataset()
+        # the demo dataset is small, so the train split doubles as the dev split
+        dev_ds = ds['train'].to_hf_dataset()
+
+        self.finetune(
+            model_id='damo/nlp_corom_passage-ranking_english-base',
+            train_dataset=train_ds,
+            eval_dataset=dev_ds,
+            cfg_modify_fn=cfg_modify_fn)
+
+        output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)
+        self.pipeline_passage_ranking(output_dir)
+
+    def pipeline_passage_ranking(self, model_dir):
+        model = Model.from_pretrained(model_dir)
+        pipeline_ins = pipeline(task=Tasks.passage_ranking, model=model)
+        print(pipeline_ins(input=self.inputs))
+
+
+if __name__ == '__main__':
+    unittest.main()