Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9856179 (master)
@@ -193,6 +193,8 @@ class Pipelines(object):
     plug_generation = 'plug-generation'
     faq_question_answering = 'faq-question-answering'
     conversational_text_to_sql = 'conversational-text-to-sql'
+    sentence_embedding = 'sentence-embedding'
+    passage_ranking = 'passage-ranking'
     relation_extraction = 'relation-extraction'
     document_segmentation = 'document-segmentation'
@@ -245,6 +247,7 @@ class Trainers(object):
     dialog_intent_trainer = 'dialog-intent-trainer'
     nlp_base_trainer = 'nlp-base-trainer'
     nlp_veco_trainer = 'nlp-veco-trainer'
+    nlp_passage_ranking_trainer = 'nlp-passage-ranking-trainer'

     # audio trainers
     speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
@@ -272,6 +275,7 @@ class Preprocessors(object):
     # nlp preprocessor
     sen_sim_tokenizer = 'sen-sim-tokenizer'
+    cross_encoder_tokenizer = 'cross-encoder-tokenizer'
     bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
     text_gen_tokenizer = 'text-gen-tokenizer'
     token_cls_tokenizer = 'token-cls-tokenizer'
@@ -284,6 +288,8 @@ class Preprocessors(object):
     sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'
     zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer'
     text_error_correction = 'text-error-correction'
+    sentence_embedding = 'sentence-embedding'
+    passage_ranking = 'passage-ranking'
     sequence_labeling_tokenizer = 'sequence-labeling-tokenizer'
     word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor'
     fill_mask = 'fill-mask'
@@ -29,6 +29,8 @@ if TYPE_CHECKING:
                               SingleBackboneTaskModelBase,
                               TokenClassificationModel)
     from .token_classification import SbertForTokenClassification
+    from .sentence_embedding import SentenceEmbedding
+    from .passage_ranking import PassageRanking

 else:
     _import_structure = {
@@ -62,6 +64,8 @@ else:
             'SingleBackboneTaskModelBase', 'TokenClassificationModel'
         ],
         'token_classification': ['SbertForTokenClassification'],
+        'sentence_embedding': ['SentenceEmbedding'],
+        'passage_ranking': ['PassageRanking'],
     }

     import sys
@@ -0,0 +1,78 @@
+from typing import Any, Dict
+
+import numpy as np
+import torch
+
+from modelscope.metainfo import Models
+from modelscope.models import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.models.nlp import SbertForSequenceClassification
+from modelscope.models.nlp.structbert import SbertPreTrainedModel
+from modelscope.outputs import OutputKeys
+from modelscope.utils.constant import Tasks
+
+__all__ = ['PassageRanking']
+
+
+@MODELS.register_module(Tasks.passage_ranking, module_name=Models.bert)
+class PassageRanking(SbertForSequenceClassification, SbertPreTrainedModel):
+    base_model_prefix: str = 'bert'
+    supports_gradient_checkpointing = True
+    _keys_to_ignore_on_load_missing = [r'position_ids']
+
+    def __init__(self, config, model_dir, *args, **kwargs):
+        if hasattr(config, 'base_model_prefix'):
+            PassageRanking.base_model_prefix = config.base_model_prefix
+        super().__init__(config, model_dir)
+        self.train_batch_size = kwargs.get('train_batch_size', 4)
+        self.register_buffer(
+            'target_label',
+            torch.zeros(self.train_batch_size, dtype=torch.long))
+
+    def build_base_model(self):
+        from .structbert import SbertModel
+        return SbertModel(self.config, add_pooling_layer=True)
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
+        outputs = self.base_model.forward(**input)
+        # backbone model should return pooled_output as its second output
+        pooled_output = outputs[1]
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+        if self.base_model.training:
+            scores = logits.view(self.train_batch_size, -1)
+            loss_fct = torch.nn.CrossEntropyLoss()
+            loss = loss_fct(scores, self.target_label)
+            return {OutputKeys.LOGITS: logits, OutputKeys.LOSS: loss}
+        return {OutputKeys.LOGITS: logits}
+
+    def sigmoid(self, logits):
+        return np.exp(logits) / (1 + np.exp(logits))
+
+    def postprocess(self, inputs: Dict[str, np.ndarray],
+                    **kwargs) -> Dict[str, np.ndarray]:
+        logits = inputs['logits'].squeeze(-1).detach().cpu().numpy()
+        logits = self.sigmoid(logits).tolist()
+        result = {OutputKeys.SCORES: logits}
+        return result
+
+    @classmethod
+    def _instantiate(cls, **kwargs):
+        """Instantiate the model.
+
+        @param kwargs: Input args.
+            model_dir: The model dir used to load the checkpoint and the label information.
+            num_labels: An optional arg to tell the model how many classes to initialize.
+                        If num_labels is not supplied, the default setting (1 class) is used.
+        @return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
+        """
+        num_labels = kwargs.get('num_labels', 1)
+        model_args = {} if num_labels is None else {'num_labels': num_labels}
+        return super(SbertPreTrainedModel, PassageRanking).from_pretrained(
+            pretrained_model_name_or_path=kwargs.get('model_dir'),
+            model_dir=kwargs.get('model_dir'),
+            **model_args)
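For orientation, the training branch of forward() implements a listwise softmax loss over each query's passage group. A minimal self-contained sketch with illustrative sizes (2 queries, groups of 1 positive followed by 4 negatives; the positive is always placed first, which is why the target_label buffer is all zeros):

import torch

train_batch_size, group_size = 2, 5  # illustrative sizes, not fixed by the model
logits = torch.randn(train_batch_size * group_size, 1)  # classifier output with num_labels=1
scores = logits.view(train_batch_size, -1)  # (2, 5): one row of group scores per query
target_label = torch.zeros(train_batch_size, dtype=torch.long)  # positive sits at index 0
loss = torch.nn.CrossEntropyLoss()(scores, target_label)
print(loss)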
@@ -0,0 +1,74 @@
+import os
+from typing import Any, Dict
+
+import json
+import numpy as np
+
+from modelscope.metainfo import Models
+from modelscope.models import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.models.nlp.structbert import SbertPreTrainedModel
+from modelscope.utils.constant import Tasks
+
+__all__ = ['SentenceEmbedding']
+
+
+@MODELS.register_module(Tasks.sentence_embedding, module_name=Models.bert)
+class SentenceEmbedding(TorchModel, SbertPreTrainedModel):
+    base_model_prefix: str = 'bert'
+    supports_gradient_checkpointing = True
+    _keys_to_ignore_on_load_missing = [r'position_ids']
+
+    def __init__(self, config, model_dir):
+        super().__init__(model_dir)
+        self.config = config
+        setattr(self, self.base_model_prefix, self.build_base_model())
+
+    def build_base_model(self):
+        from .structbert import SbertModel
+        return SbertModel(self.config, add_pooling_layer=False)
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
+        """Return the backbone outputs for the preprocessed sentences.
+
+        Args:
+            input (Dict[str, Any]): the preprocessed data
+
+        Returns:
+            Dict[str, np.ndarray]: the backbone outputs, including
+                'last_hidden_state' of shape [num_sentences, seq_len, hidden_size];
+                postprocess takes the [CLS] vectors as the sentence embeddings
+        """
+        return self.base_model(**input)
+
+    def postprocess(self, inputs: Dict[str, np.ndarray],
+                    **kwargs) -> Dict[str, np.ndarray]:
+        embs = inputs['last_hidden_state'][:, 0].cpu().numpy()
+        num_sent = embs.shape[0]
+        if num_sent >= 2:
+            scores = np.dot(embs[0:1, ], np.transpose(embs[1:, ],
+                                                      (1, 0))).tolist()[0]
+        else:
+            scores = []
+        result = {'text_embedding': embs, 'scores': scores}
+        return result
+
+    @classmethod
+    def _instantiate(cls, **kwargs):
+        """Instantiate the model.
+
+        @param kwargs: Input args.
+            model_dir: The model dir used to load the checkpoint and the label information.
+        @return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
+        """
+        model_args = {}
+        return super(SbertPreTrainedModel, SentenceEmbedding).from_pretrained(
+            pretrained_model_name_or_path=kwargs.get('model_dir'),
+            model_dir=kwargs.get('model_dir'),
+            **model_args)
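The scoring in postprocess() is a plain inner product between the [CLS] embedding of the source sentence (row 0) and each candidate (rows 1 onward). A toy check of the exact expression used above:

import numpy as np

embs = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]])  # row 0: source; rows 1+: candidates
scores = np.dot(embs[0:1, ], np.transpose(embs[1:, ], (1, 0))).tolist()[0]
print(scores)  # [0.9, 0.0]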
@@ -11,12 +11,14 @@ if TYPE_CHECKING:
     from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset
     from .movie_scene_segmentation import MovieSceneSegmentationDataset
     from .video_summarization_dataset import VideoSummarizationDataset
+    from .passage_ranking_dataset import PassageRankingDataset

 else:
     _import_structure = {
         'base': ['TaskDataset'],
         'builder': ['TASK_DATASETS', 'build_task_dataset'],
         'torch_base_dataset': ['TorchTaskDataset'],
+        'passage_ranking_dataset': ['PassageRankingDataset'],
         'veco_dataset': ['VecoDataset'],
         'image_instance_segmentation_coco_dataset':
         ['ImageInstanceSegmentationCocoDataset'],
@@ -0,0 +1,151 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import random
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple, Union
+
+import torch
+from datasets import Dataset, IterableDataset, concatenate_datasets
+from torch.utils.data import ConcatDataset
+from transformers import DataCollatorWithPadding
+
+from modelscope.metainfo import Models
+from modelscope.utils.constant import ModeKeys, Tasks
+from .base import TaskDataset
+from .builder import TASK_DATASETS
+from .torch_base_dataset import TorchTaskDataset
+
+
+@TASK_DATASETS.register_module(
+    group_key=Tasks.passage_ranking, module_name=Models.bert)
+class PassageRankingDataset(TorchTaskDataset):
+
+    def __init__(self,
+                 datasets: Union[Any, List[Any]],
+                 mode,
+                 preprocessor=None,
+                 *args,
+                 **kwargs):
+        self.seed = kwargs.get('seed', 42)
+        self.permutation = None
+        self.datasets = None
+        self.dataset_config = kwargs
+        self.query_sequence = self.dataset_config.get('query_sequence',
+                                                      'query')
+        self.pos_sequence = self.dataset_config.get('pos_sequence',
+                                                    'positive_passages')
+        self.neg_sequence = self.dataset_config.get('neg_sequence',
+                                                    'negative_passages')
+        self.passage_text_fields = self.dataset_config.get(
+            'passage_text_fields', ['title', 'text'])
+        self.qid_field = self.dataset_config.get('qid_field', 'query_id')
+        if mode == ModeKeys.TRAIN:
+            train_config = kwargs.get('train', {})
+            self.neg_samples = train_config.get('neg_samples', 4)
+        super().__init__(datasets, mode, preprocessor, **kwargs)
+
+    def __getitem__(self, index) -> Any:
+        if self.mode == ModeKeys.TRAIN:
+            return self.__get_train_item__(index)
+        else:
+            return self.__get_test_item__(index)
+
+    def __get_test_item__(self, index):
+        group = self._inner_dataset[index]
+        labels = []
+        qry = group[self.query_sequence]
+        pos_sequences = group[self.pos_sequence]
+        pos_sequences = [
+            ' '.join([ele[key] for key in self.passage_text_fields])
+            for ele in pos_sequences
+        ]
+        labels.extend([1] * len(pos_sequences))
+        neg_sequences = group[self.neg_sequence]
+        neg_sequences = [
+            ' '.join([ele[key] for key in self.passage_text_fields])
+            for ele in neg_sequences
+        ]
+        labels.extend([0] * len(neg_sequences))
+        qid = group[self.qid_field]
+        examples = pos_sequences + neg_sequences
+        sample = {
+            'qid': torch.LongTensor([int(qid)] * len(labels)),
+            self.preprocessor.first_sequence: qry,
+            self.preprocessor.second_sequence: examples,
+            'labels': torch.LongTensor(labels)
+        }
+        return self.prepare_sample(sample)
+
+    def __get_train_item__(self, index):
+        group = self._inner_dataset[index]
+        qry = group[self.query_sequence]
+        pos_sequences = group[self.pos_sequence]
+        pos_sequences = [
+            ' '.join([ele[key] for key in self.passage_text_fields])
+            for ele in pos_sequences
+        ]
+        neg_sequences = group[self.neg_sequence]
+        neg_sequences = [
+            ' '.join([ele[key] for key in self.passage_text_fields])
+            for ele in neg_sequences
+        ]
+        pos_psg = random.choice(pos_sequences)
+        if len(neg_sequences) < self.neg_samples:
+            negs = random.choices(neg_sequences, k=self.neg_samples)
+        else:
+            negs = random.sample(neg_sequences, k=self.neg_samples)
+        examples = [pos_psg] + negs
+        sample = {
+            self.preprocessor.first_sequence: qry,
+            self.preprocessor.second_sequence: examples,
+        }
+        return self.prepare_sample(sample)
+
+    def __len__(self):
+        return len(self._inner_dataset)
+
+    def prepare_dataset(self, datasets: Union[Any, List[Any]]) -> Any:
+        """Prepare a dataset.
+
+        Users can process the input datasets from a whole-dataset perspective.
+        This method gives a default implementation of dataset merging; users can
+        override it to add custom logic.
+
+        Args:
+            datasets: The original dataset(s)
+
+        Returns: A single dataset, which may be created after merging.
+        """
+        if isinstance(datasets, List):
+            if len(datasets) == 1:
+                return datasets[0]
+            elif len(datasets) > 1:
+                return ConcatDataset(datasets)
+        else:
+            return datasets
+
+    def prepare_sample(self, data):
+        """Preprocess the data fetched from the inner dataset.
+
+        If the preprocessor is None, the original data is returned; otherwise the
+        preprocessor is called. Users can override this method to add custom logic.
+
+        Args:
+            data: The data fetched from the dataset.
+
+        Returns: The processed data.
+        """
+        return self.preprocessor(
+            data) if self.preprocessor is not None else data
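To make the training-time sampling concrete, here is a self-contained sketch of what __get_train_item__ does with one raw group (field names follow the defaults above; the passages themselves are made up for illustration):

import random

group = {
    'query': "how long it take to get a master's degree",
    'positive_passages': [{'title': 'Masters', 'text': 'about 18 to 24 months'}],
    'negative_passages': [{'title': 'Other', 'text': 'negative passage %d' % i}
                          for i in range(3)],
}
neg_samples = 4
pos = [' '.join([p['title'], p['text']]) for p in group['positive_passages']]
negs = [' '.join([p['title'], p['text']]) for p in group['negative_passages']]
# fewer negatives than requested, so sample with replacement, as in __get_train_item__
picked = (random.choices(negs, k=neg_samples)
          if len(negs) < neg_samples else random.sample(negs, k=neg_samples))
examples = [random.choice(pos)] + picked  # positive first, matching the model's target_label of 0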
@@ -387,19 +387,14 @@ TASK_OUTPUTS = {
     #     "output": "我想吃苹果"
     # }
     Tasks.text_error_correction: [OutputKeys.OUTPUT],
+    Tasks.sentence_embedding: [OutputKeys.TEXT_EMBEDDING, OutputKeys.SCORES],
+    Tasks.passage_ranking: [OutputKeys.SCORES],

     # text generation result for single sample
     # {
     #     "text": "this is the text generated by a model."
     # }
     Tasks.text_generation: [OutputKeys.TEXT],

-    # text feature extraction for single sample
-    # {
-    #     "text_embedding": np.array with shape [1, D]
-    # }
-    Tasks.sentence_embedding: [OutputKeys.TEXT_EMBEDDING],
-
     # fill mask result for single sample
     # {
     #     "text": "this is the text which masks filled by model."
@@ -17,6 +17,11 @@ PIPELINES = Registry('pipelines')

 DEFAULT_MODEL_FOR_PIPELINE = {
     # TaskName: (pipeline_module_name, model_repo)
+    Tasks.sentence_embedding:
+    (Pipelines.sentence_embedding,
+     'damo/nlp_corom_sentence-embedding_english-base'),
+    Tasks.passage_ranking: (Pipelines.passage_ranking,
+                            'damo/nlp_corom_passage-ranking_english-base'),
     Tasks.word_segmentation:
     (Pipelines.word_segmentation,
      'damo/nlp_structbert_word-segmentation_chinese-base'),
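With these defaults registered, a pipeline can be built from the task name alone; the model id is resolved from this table and downloaded from the hub:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

p = pipeline(task=Tasks.passage_ranking)  # resolves to 'damo/nlp_corom_passage-ranking_english-base'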
@@ -25,7 +25,8 @@ if TYPE_CHECKING:
     from .translation_pipeline import TranslationPipeline
     from .word_segmentation_pipeline import WordSegmentationPipeline
     from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline
+    from .passage_ranking_pipeline import PassageRankingPipeline
+    from .sentence_embedding_pipeline import SentenceEmbeddingPipeline

 else:
     _import_structure = {
         'conversational_text_to_sql_pipeline':
@@ -55,6 +56,8 @@ else:
         'word_segmentation_pipeline': ['WordSegmentationPipeline'],
         'zero_shot_classification_pipeline':
         ['ZeroShotClassificationPipeline'],
+        'passage_ranking_pipeline': ['PassageRankingPipeline'],
+        'sentence_embedding_pipeline': ['SentenceEmbeddingPipeline']
     }

     import sys
@@ -0,0 +1,58 @@
+from typing import Any, Dict, Optional, Union
+
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models import Model
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import PassageRankingPreprocessor, Preprocessor
+from modelscope.utils.constant import Tasks
+
+__all__ = ['PassageRankingPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.passage_ranking, module_name=Pipelines.passage_ranking)
+class PassageRankingPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[Model, str],
+                 preprocessor: Optional[Preprocessor] = None,
+                 **kwargs):
+        """Use `model` and `preprocessor` to create an nlp passage ranking pipeline for prediction.
+
+        Args:
+            model (str or Model): Supply either a local model dir which supports the passage-ranking task,
+                or a model id from the model hub, or a torch model instance.
+            preprocessor (Preprocessor): An optional preprocessor instance; please make sure the preprocessor
+                fits the model if supplied.
+            sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value.
+        """
+        model = model if isinstance(model,
+                                    Model) else Model.from_pretrained(model)
+        if preprocessor is None:
+            preprocessor = PassageRankingPreprocessor(
+                model.model_dir if isinstance(model, Model) else model,
+                sequence_length=kwargs.pop('sequence_length', 128))
+        model.eval()
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+
+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
+        with torch.no_grad():
+            return {**self.model(inputs, **forward_params)}
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Process the prediction results.
+
+        Args:
+            inputs (Dict[str, Any]): the model outputs
+
+        Returns:
+            Dict[str, Any]: one relevance score per candidate passage
+        """
+        pred_list = inputs[OutputKeys.SCORES]
+        return {OutputKeys.SCORES: pred_list}
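Expected usage, mirroring the unit tests at the end of this change (the input keys match the preprocessor defaults first_sequence='source_sentence' and second_sequence='sentences_to_compare'):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

p = pipeline(task=Tasks.passage_ranking,
             model='damo/nlp_corom_passage-ranking_english-base')
result = p(input={
    'source_sentence': ["how long it take to get a master's degree"],
    'sentences_to_compare': [
        "On average, students take about 18 to 24 months to complete a master's degree.",
        'It can take anywhere from two semesters',
    ],
})
print(result)  # {'scores': [...]} with one sigmoid score per candidate passage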
@@ -0,0 +1,60 @@
+from typing import Any, Dict, Optional, Union
+
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models import Model
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import (Preprocessor,
+                                      SentenceEmbeddingPreprocessor)
+from modelscope.utils.constant import Tasks
+
+__all__ = ['SentenceEmbeddingPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.sentence_embedding, module_name=Pipelines.sentence_embedding)
+class SentenceEmbeddingPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[Model, str],
+                 preprocessor: Optional[Preprocessor] = None,
+                 first_sequence='first_sequence',
+                 **kwargs):
+        """Use `model` and `preprocessor` to create an nlp text dual encoder pipeline that generates
+        the text representation.
+
+        Args:
+            model (str or Model): Supply either a local model dir which supports the sentence-embedding task,
+                or a model id from the model hub, or a torch model instance.
+            preprocessor (Preprocessor): An optional preprocessor instance; please make sure the preprocessor
+                fits the model if supplied.
+            sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value.
+        """
+        model = model if isinstance(model,
+                                    Model) else Model.from_pretrained(model)
+        if preprocessor is None:
+            preprocessor = SentenceEmbeddingPreprocessor(
+                model.model_dir if isinstance(model, Model) else model,
+                first_sequence=first_sequence,
+                sequence_length=kwargs.pop('sequence_length', 128))
+        model.eval()
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+
+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
+        with torch.no_grad():
+            return {**self.model(inputs, **forward_params)}
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Process the prediction results.
+
+        Args:
+            inputs (Dict[str, Any]): the model outputs
+
+        Returns:
+            Dict[str, Any]: the sentence embeddings and the source-vs-candidate scores
+        """
+        embs = inputs[OutputKeys.TEXT_EMBEDDING]
+        scores = inputs[OutputKeys.SCORES]
+        return {OutputKeys.TEXT_EMBEDDING: embs, OutputKeys.SCORES: scores}
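Expected usage, again mirroring the unit tests below; the first returned embedding row belongs to the source sentence:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

p = pipeline(task=Tasks.sentence_embedding,
             model='damo/nlp_corom_sentence-embedding_english-base')
result = p(input={
    'source_sentence': ["how long it take to get a master's degree"],
    'sentences_to_compare': ['It can take anywhere from two semesters'],
})
# result['text_embedding']: (2, D) [CLS] embeddings for source + candidate
# result['scores']: inner product of the source embedding with each candidate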
@@ -23,7 +23,8 @@ if TYPE_CHECKING:
         ZeroShotClassificationPreprocessor, NERPreprocessor,
         TextErrorCorrectionPreprocessor, FaqQuestionAnsweringPreprocessor,
         SequenceLabelingPreprocessor, RelationExtractionPreprocessor,
-        DocumentSegmentationPreprocessor, FillMaskPoNetPreprocessor)
+        DocumentSegmentationPreprocessor, FillMaskPoNetPreprocessor,
+        PassageRankingPreprocessor)
     from .space import (DialogIntentPredictionPreprocessor,
                         DialogModelingPreprocessor,
                         DialogStateTrackingPreprocessor)
@@ -50,6 +51,7 @@ else:
             'SingleSentenceClassificationPreprocessor',
             'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
             'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
+            'SentenceEmbeddingPreprocessor', 'PassageRankingPreprocessor',
             'TextErrorCorrectionPreprocessor',
             'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor',
             'RelationExtractionPreprocessor',
@@ -29,6 +29,7 @@ __all__ = [
     'PairSentenceClassificationPreprocessor',
     'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
     'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
+    'SentenceEmbeddingPreprocessor', 'PassageRankingPreprocessor',
     'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor',
     'SequenceLabelingPreprocessor', 'RelationExtractionPreprocessor',
     'DocumentSegmentationPreprocessor', 'FillMaskPoNetPreprocessor'
@@ -100,6 +101,7 @@ class SequenceClassificationPreprocessor(Preprocessor):
         text_a = new_data[self.first_sequence]
         text_b = new_data.get(self.second_sequence, None)
+
         feature = self.tokenizer(
             text_a,
             text_b,
@@ -111,7 +113,6 @@ class SequenceClassificationPreprocessor(Preprocessor):
         rst['input_ids'].append(feature['input_ids'])
         rst['attention_mask'].append(feature['attention_mask'])
         rst['token_type_ids'].append(feature['token_type_ids'])
-
         return rst
@@ -268,6 +269,62 @@ class NLPTokenizerPreprocessorBase(Preprocessor):
             output[OutputKeys.LABELS] = labels


+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=Preprocessors.passage_ranking)
+class PassageRankingPreprocessor(NLPTokenizerPreprocessorBase):
+    """The tokenizer preprocessor used in the passage ranking model."""
+
+    def __init__(self,
+                 model_dir: str,
+                 mode=ModeKeys.INFERENCE,
+                 *args,
+                 **kwargs):
+        """Preprocess the data.
+
+        Args:
+            model_dir (str): model path
+        """
+        super().__init__(model_dir, pair=True, mode=mode, *args, **kwargs)
+        self.model_dir: str = model_dir
+        self.first_sequence: str = kwargs.pop('first_sequence',
+                                              'source_sentence')
+        self.second_sequence = kwargs.pop('second_sequence',
+                                          'sentences_to_compare')
+        self.sequence_length = kwargs.pop('sequence_length', 128)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
+
+    @type_assert(object, (str, tuple, Dict))
+    def __call__(self, data: Union[tuple, Dict]) -> Dict[str, Any]:
+        if isinstance(data, tuple):
+            sentence1, sentence2 = data
+        elif isinstance(data, dict):
+            sentence1 = data.get(self.first_sequence)
+            sentence2 = data.get(self.second_sequence)
+        if isinstance(sentence2, str):
+            sentence2 = [sentence2]
+        if isinstance(sentence1, str):
+            sentence1 = [sentence1]
+        sentence1 = sentence1 * len(sentence2)
+        max_seq_length = self.sequence_length
+        feature = self.tokenizer(
+            sentence1,
+            sentence2,
+            padding='max_length',
+            truncation=True,
+            max_length=max_seq_length,
+            return_tensors='pt')
+        if 'labels' in data:
+            labels = data['labels']
+            feature['labels'] = labels
+        if 'qid' in data:
+            qid = data['qid']
+            feature['qid'] = qid
+        return feature
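The key move in __call__ above is replicating the single query once per candidate so the tokenizer receives aligned (query, passage) pairs. A dependency-free sketch of that expansion:

sentence1 = ["how long it take to get a master's degree"]  # the query
sentence2 = ['about 18 to 24 months', 'two semesters', 'several years']  # candidates
sentence1 = sentence1 * len(sentence2)  # query repeated 3 times, aligned with sentence2
# self.tokenizer(sentence1, sentence2, ...) then encodes each pair as a single
# '[CLS] query [SEP] passage [SEP]' sequence for the cross-encoder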
+
+
 @PREPROCESSORS.register_module(
     Fields.nlp, module_name=Preprocessors.nli_tokenizer)
 @PREPROCESSORS.register_module(
@@ -298,6 +355,51 @@ class SingleSentenceClassificationPreprocessor(NLPTokenizerPreprocessorBase):
         super().__init__(model_dir, pair=False, mode=mode, **kwargs)


+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=Preprocessors.sentence_embedding)
+class SentenceEmbeddingPreprocessor(NLPTokenizerPreprocessorBase):
+    """The tokenizer preprocessor used in sentence embedding."""
+
+    def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs):
+        kwargs['truncation'] = kwargs.get('truncation', True)
+        kwargs['padding'] = kwargs.get(
+            'padding', False if mode == ModeKeys.INFERENCE else 'max_length')
+        kwargs['max_length'] = kwargs.pop('sequence_length', 128)
+        super().__init__(model_dir, pair=False, mode=mode, **kwargs)
+
+    def __call__(self, data: Union[str, Dict]) -> Dict[str, Any]:
+        """Process the raw input data.
+
+        Args:
+            data (Dict): a dict with the keys "source_sentence" and "sentences_to_compare",
+                whose values are lists of sentences
+
+        Example:
+            {"source_sentence": ["how long it take to get a master's degree"],
+             "sentences_to_compare": ["On average, students take about 18 to 24 months
+                 to complete a master's degree.",
+                 "On the other hand, some students prefer to go at a slower pace
+                 and choose to take several years to complete their studies.",
+                 "It can take anywhere from two semesters"]}
+
+        Returns:
+            Dict[str, Any]: the preprocessed data
+        """
+        source_sentence = data['source_sentence']
+        compare_sentences = data['sentences_to_compare']
+        sentences = []
+        sentences.append(source_sentence[0])
+        for sent in compare_sentences:
+            sentences.append(sent)
+        tokenized_inputs = self.tokenizer(
+            sentences,
+            return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None,
+            padding=True,
+            truncation=True)
+        return tokenized_inputs
+
+
 @PREPROCESSORS.register_module(
     Fields.nlp, module_name=Preprocessors.zero_shot_cls_tokenizer)
 class ZeroShotClassificationPreprocessor(NLPTokenizerPreprocessorBase):
@@ -11,7 +11,7 @@ if TYPE_CHECKING:
                       ImagePortraitEnhancementTrainer,
                       MovieSceneSegmentationTrainer)
     from .multi_modal import CLIPTrainer
-    from .nlp import SequenceClassificationTrainer
+    from .nlp import SequenceClassificationTrainer, PassageRankingTrainer
     from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer
     from .trainer import EpochBasedTrainer
@@ -25,7 +25,7 @@ else:
             'ImagePortraitEnhancementTrainer', 'MovieSceneSegmentationTrainer'
         ],
         'multi_modal': ['CLIPTrainer'],
-        'nlp': ['SequenceClassificationTrainer'],
+        'nlp': ['SequenceClassificationTrainer', 'PassageRankingTrainer'],
        'nlp_trainer': ['NlpEpochBasedTrainer', 'VecoTrainer'],
         'trainer': ['EpochBasedTrainer']
     }
@@ -6,10 +6,12 @@ from modelscope.utils.import_utils import LazyImportModule

 if TYPE_CHECKING:
     from .sequence_classification_trainer import SequenceClassificationTrainer
     from .csanmt_translation_trainer import CsanmtTranslationTrainer
+    from .passage_ranking_trainer import PassageRankingTrainer

 else:
     _import_structure = {
         'sequence_classification_trainer': ['SequenceClassificationTrainer'],
         'csanmt_translation_trainer': ['CsanmtTranslationTrainer'],
+        'passage_ranking_trainer': ['PassageRankingTrainer']
     }

     import sys
@@ -0,0 +1,197 @@
+import time
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.utils.data import DataLoader, Dataset
+
+from modelscope.metainfo import Trainers
+from modelscope.models.base import Model, TorchModel
+from modelscope.msdatasets.ms_dataset import MsDataset
+from modelscope.preprocessors.base import Preprocessor
+from modelscope.trainers.base import BaseTrainer
+from modelscope.trainers.builder import TRAINERS
+from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer
+from modelscope.utils.constant import DEFAULT_MODEL_REVISION
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@dataclass
+class GroupCollator():
+    """
+    Wrapper that converts List[Tuple[encode_qry, encode_psg]] to List[qry], List[psg]
+    and passes each batch separately to the actual collator.
+    Abstracts the data details away from the model.
+    """
+
+    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
+        if isinstance(features[0], list):
+            features = sum(features, [])
+        keys = features[0].keys()
+        batch = {k: list() for k in keys}
+        for ele in features:
+            for k, v in ele.items():
+                batch[k].append(v)
+        batch = {k: torch.cat(v, dim=0) for k, v in batch.items()}
+        return batch
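A toy run of GroupCollator, assuming each per-query feature dict holds (group_size, seq_len) tensors as produced by PassageRankingPreprocessor; concatenating along dim 0 flattens all query-passage pairs into one batch:

import torch

features = [
    {'input_ids': torch.ones(5, 8, dtype=torch.long)},   # query 1: group of 5 pairs
    {'input_ids': torch.zeros(5, 8, dtype=torch.long)},  # query 2: group of 5 pairs
]
batch = GroupCollator()(features)  # GroupCollator as defined above
print(batch['input_ids'].shape)  # torch.Size([10, 8])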
+
+
+@TRAINERS.register_module(module_name=Trainers.nlp_passage_ranking_trainer)
+class PassageRankingTrainer(NlpEpochBasedTrainer):
+
+    def __init__(
+            self,
+            model: Optional[Union[TorchModel, nn.Module, str]] = None,
+            cfg_file: Optional[str] = None,
+            cfg_modify_fn: Optional[Callable] = None,
+            arg_parse_fn: Optional[Callable] = None,
+            data_collator: Optional[Callable] = None,
+            train_dataset: Optional[Union[MsDataset, Dataset]] = None,
+            eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
+            preprocessor: Optional[Preprocessor] = None,
+            optimizers: Tuple[torch.optim.Optimizer,
+                              torch.optim.lr_scheduler._LRScheduler] = (None,
+                                                                        None),
+            model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
+            **kwargs):
+        if data_collator is None:
+            data_collator = GroupCollator()
+        super().__init__(
+            model=model,
+            cfg_file=cfg_file,
+            cfg_modify_fn=cfg_modify_fn,
+            arg_parse_fn=arg_parse_fn,
+            data_collator=data_collator,
+            preprocessor=preprocessor,
+            optimizers=optimizers,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            model_revision=model_revision,
+            **kwargs)
+
+    def compute_mrr(self, result, k=10):
+        mrr = 0
+        for res in result.values():
+            sorted_res = sorted(res, key=lambda x: x[0], reverse=True)
+            ar = 0
+            for index, ele in enumerate(sorted_res[:k]):
+                if str(ele[1]) == '1':
+                    ar = 1.0 / (index + 1)
+                    break
+            mrr += ar
+        return mrr / len(result)
+
+    def compute_ndcg(self, result, k=10):
+        ndcg = 0
+        from sklearn.metrics import ndcg_score
+        for res in result.values():
+            sorted_res = sorted(res, key=lambda x: x[0], reverse=True)
+            labels = np.array([[ele[1] for ele in sorted_res]])
+            scores = np.array([[ele[0] for ele in sorted_res]])
+            ndcg += float(ndcg_score(labels, scores, k=k))
+        ndcg = ndcg / len(result)
+        return ndcg
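A worked example for compute_mrr: per query the candidates are sorted by score in descending order, and the reciprocal rank of the first relevant hit (label 1) within the top k is averaged over queries:

result = {
    'q1': [(0.9, 0), (0.7, 1), (0.1, 0)],  # first relevant hit at rank 2 -> 1/2
    'q2': [(0.8, 1), (0.2, 0)],            # first relevant hit at rank 1 -> 1/1
}
# compute_mrr(result, k=10) == (0.5 + 1.0) / 2 == 0.75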
+
+    def evaluate(self,
+                 checkpoint_path: Optional[str] = None,
+                 *args,
+                 **kwargs) -> Dict[str, float]:
+        """Evaluate a dataset.
+
+        Evaluate a dataset via a specific model loaded from `checkpoint_path`; if
+        `checkpoint_path` is not given, the model from the config file is used.
+
+        Args:
+            checkpoint_path (Optional[str], optional): the model path. Defaults to None.
+
+        Returns:
+            Dict[str, float]: the evaluation results
+
+        Example:
+            {"mrr@10": 0.5091, "ndcg@10": 0.6737}
+        """
+        from modelscope.models.nlp import PassageRanking
+        # get the raw online dataset and generate a standard dataloader
+        self.eval_dataloader = self._build_dataloader_with_dataset(
+            self.eval_dataset,
+            **self.cfg.evaluation.get('dataloader', {}),
+            collate_fn=self.eval_data_collator)
+        # generate a model
+        if checkpoint_path is not None:
+            model = PassageRanking.from_pretrained(checkpoint_path)
+        else:
+            model = self.model
+        # copy from easynlp (start)
+        model.eval()
+        total_samples = 0
+        logits_list = list()
+        label_list = list()
+        qid_list = list()
+        total_spent_time = 0.0
+        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        model.to(device)
+        for _step, batch in enumerate(self.eval_dataloader):
+            try:
+                batch = {
+                    key:
+                    val.to(device) if isinstance(val, torch.Tensor) else val
+                    for key, val in batch.items()
+                }
+            except RuntimeError:
+                batch = {key: val for key, val in batch.items()}
+            infer_start_time = time.time()
+            with torch.no_grad():
+                label_ids = batch.pop('labels').detach().cpu().numpy()
+                qids = batch.pop('qid').detach().cpu().numpy()
+                outputs = model(batch)
+            infer_end_time = time.time()
+            total_spent_time += infer_end_time - infer_start_time
+            total_samples += self.eval_dataloader.batch_size
+            assert 'scores' in outputs
+            logits = outputs['scores']
+            label_list.extend(label_ids)
+            logits_list.extend(logits)
+            qid_list.extend(qids)
+        logger.info('Inference time = {:.2f}s, [{:.4f} ms / sample] '.format(
+            total_spent_time, total_spent_time * 1000 / total_samples))
+        rank_result = {}
+        for qid, score, label in zip(qid_list, logits_list, label_list):
+            if qid not in rank_result:
+                rank_result[qid] = []
+            rank_result[qid].append((score, label))
+        for qid in rank_result:
+            rank_result[qid] = sorted(rank_result[qid], key=lambda x: x[0])
+        eval_outputs = list()
+        for metric in self.metrics:
+            if metric.startswith('mrr'):
+                k = metric.split('@')[-1]
+                k = int(k)
+                mrr = self.compute_mrr(rank_result, k=k)
+                logger.info('{}: {}'.format(metric, mrr))
+                eval_outputs.append((metric, mrr))
+            elif metric.startswith('ndcg'):
+                k = metric.split('@')[-1]
+                k = int(k)
+                ndcg = self.compute_ndcg(rank_result, k=k)
+                logger.info('{}: {}'.format(metric, ndcg))
+                eval_outputs.append((metric, ndcg))
+            else:
+                raise NotImplementedError('Metric %s not implemented' % metric)
+        return dict(eval_outputs)
@@ -345,12 +345,12 @@ class EpochBasedTrainer(BaseTrainer):
                     type=self.cfg.task, mode=mode, datasets=datasets)
                 return build_task_dataset(cfg, self.cfg.task)
             else:
-                task_data_config.update(
-                    dict(
-                        mode=mode,
-                        datasets=datasets,
-                        preprocessor=preprocessor))
-                return build_task_dataset(task_data_config, self.cfg.task)
+                # avoid putting non-str values (datasets, preprocessor) into the cfg itself
+                task_data_build_config = ConfigDict(
+                    mode=mode, datasets=datasets, preprocessor=preprocessor)
+                task_data_build_config.update(task_data_config)
+                return build_task_dataset(task_data_build_config,
+                                          self.cfg.task)
         except Exception:
             if isinstance(datasets, (List, Tuple)) or preprocessor is not None:
                 return TorchTaskDataset(
@@ -89,6 +89,8 @@ class NLPTasks(object):
     sentiment_analysis = 'sentiment-analysis'
     sentence_similarity = 'sentence-similarity'
     text_classification = 'text-classification'
+    sentence_embedding = 'sentence-embedding'
+    passage_ranking = 'passage-ranking'
     relation_extraction = 'relation-extraction'
     zero_shot = 'zero-shot'
     translation = 'translation'
@@ -0,0 +1,61 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import shutil
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models import Model
+from modelscope.models.nlp import PassageRanking
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.nlp import PassageRankingPipeline
+from modelscope.preprocessors import PassageRankingPreprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class PassageRankingTest(unittest.TestCase):
+    model_id = 'damo/nlp_corom_passage-ranking_english-base'
+    inputs = {
+        'source_sentence': ["how long it take to get a master's degree"],
+        'sentences_to_compare': [
+            "On average, students take about 18 to 24 months to complete a master's degree.",
+            'On the other hand, some students prefer to go at a slower pace and choose to take '
+            'several years to complete their studies.',
+            'It can take anywhere from two semesters'
+        ]
+    }
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_by_direct_model_download(self):
+        cache_path = snapshot_download(self.model_id)
+        tokenizer = PassageRankingPreprocessor(cache_path)
+        model = PassageRanking.from_pretrained(cache_path)
+        pipeline1 = PassageRankingPipeline(model, preprocessor=tokenizer)
+        pipeline2 = pipeline(
+            Tasks.passage_ranking, model=model, preprocessor=tokenizer)
+        print(f'sentence: {self.inputs}\n'
+              f'pipeline1:{pipeline1(input=self.inputs)}')
+        print()
+        print(f'pipeline2: {pipeline2(input=self.inputs)}')
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.model_id)
+        tokenizer = PassageRankingPreprocessor(model.model_dir)
+        pipeline_ins = pipeline(
+            task=Tasks.passage_ranking, model=model, preprocessor=tokenizer)
+        print(pipeline_ins(input=self.inputs))
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.passage_ranking, model=self.model_id)
+        print(pipeline_ins(input=self.inputs))
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_default_model(self):
+        pipeline_ins = pipeline(task=Tasks.passage_ranking)
+        print(pipeline_ins(input=self.inputs))
+
+
+if __name__ == '__main__':
+    unittest.main()
@@ -0,0 +1,82 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import shutil
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models import Model
+from modelscope.models.nlp import SentenceEmbedding
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.nlp import SentenceEmbeddingPipeline
+from modelscope.preprocessors import SentenceEmbeddingPreprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class SentenceEmbeddingTest(unittest.TestCase):
+    model_id = 'damo/nlp_corom_sentence-embedding_english-base'
+    inputs = {
+        'source_sentence': ["how long it take to get a master's degree"],
+        'sentences_to_compare': [
+            "On average, students take about 18 to 24 months to complete a master's degree.",
+            'On the other hand, some students prefer to go at a slower pace and choose to take '
+            'several years to complete their studies.',
+            'It can take anywhere from two semesters'
+        ]
+    }
+    inputs2 = {
+        'source_sentence': ["how long it take to get a master's degree"],
+        'sentences_to_compare': [
+            "On average, students take about 18 to 24 months to complete a master's degree."
+        ]
+    }
+    inputs3 = {
+        'source_sentence': ["how long it take to get a master's degree"],
+        'sentences_to_compare': []
+    }
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_by_direct_model_download(self):
+        cache_path = snapshot_download(self.model_id)
+        tokenizer = SentenceEmbeddingPreprocessor(cache_path)
+        model = SentenceEmbedding.from_pretrained(cache_path)
+        pipeline1 = SentenceEmbeddingPipeline(model, preprocessor=tokenizer)
+        pipeline2 = pipeline(
+            Tasks.sentence_embedding, model=model, preprocessor=tokenizer)
+        print(f'inputs: {self.inputs}\n'
+              f'pipeline1:{pipeline1(input=self.inputs)}')
+        print()
+        print(f'pipeline2: {pipeline2(input=self.inputs)}')
+        print()
+        print(f'inputs: {self.inputs2}\n'
+              f'pipeline1:{pipeline1(input=self.inputs2)}')
+        print()
+        print(f'pipeline2: {pipeline2(input=self.inputs2)}')
+        print(f'inputs: {self.inputs3}\n'
+              f'pipeline1:{pipeline1(input=self.inputs3)}')
+        print()
+        print(f'pipeline2: {pipeline2(input=self.inputs3)}')
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.model_id)
+        tokenizer = SentenceEmbeddingPreprocessor(model.model_dir)
+        pipeline_ins = pipeline(
+            task=Tasks.sentence_embedding, model=model, preprocessor=tokenizer)
+        print(pipeline_ins(input=self.inputs))
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.sentence_embedding, model=self.model_id)
+        print(pipeline_ins(input=self.inputs))
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_default_model(self):
+        pipeline_ins = pipeline(task=Tasks.sentence_embedding)
+        print(pipeline_ins(input=self.inputs))
+
+
+if __name__ == '__main__':
+    unittest.main()
@@ -0,0 +1,133 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import tempfile
+import unittest
+from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
+
+import torch
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
+from modelscope.metainfo import Trainers
+from modelscope.models import Model
+from modelscope.msdatasets import MsDataset
+from modelscope.pipelines import pipeline
+from modelscope.trainers import build_trainer
+from modelscope.utils.constant import ModelFile, Tasks
+
+
+class TestFinetunePassageRanking(unittest.TestCase):
+    inputs = {
+        'source_sentence': ["how long it take to get a master's degree"],
+        'sentences_to_compare': [
+            "On average, students take about 18 to 24 months to complete a master's degree.",
+            'On the other hand, some students prefer to go at a slower pace and choose to take '
+            'several years to complete their studies.',
+            'It can take anywhere from two semesters'
+        ]
+    }
+
+    def setUp(self):
+        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+        self.tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(self.tmp_dir):
+            os.makedirs(self.tmp_dir)
+
+    def tearDown(self):
+        shutil.rmtree(self.tmp_dir)
+        super().tearDown()
+
+    def finetune(self,
+                 model_id,
+                 train_dataset,
+                 eval_dataset,
+                 name=Trainers.nlp_passage_ranking_trainer,
+                 cfg_modify_fn=None,
+                 **kwargs):
+        kwargs = dict(
+            model=model_id,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            work_dir=self.tmp_dir,
+            cfg_modify_fn=cfg_modify_fn,
+            **kwargs)
+        os.environ['LOCAL_RANK'] = '0'
+        trainer = build_trainer(name=name, default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+
+    def test_finetune_msmarco(self):
+
+        def cfg_modify_fn(cfg):
+            cfg.task = 'passage-ranking'
+            cfg['preprocessor'] = {'type': 'passage-ranking'}
+            cfg.train.optimizer.lr = 2e-5
+            cfg['dataset'] = {
+                'train': {
+                    'type': 'bert',
+                    'query_sequence': 'query',
+                    'pos_sequence': 'positive_passages',
+                    'neg_sequence': 'negative_passages',
+                    'passage_text_fields': ['title', 'text'],
+                    'qid_field': 'query_id'
+                },
+                'val': {
+                    'type': 'bert',
+                    'query_sequence': 'query',
+                    'pos_sequence': 'positive_passages',
+                    'neg_sequence': 'negative_passages',
+                    'passage_text_fields': ['title', 'text'],
+                    'qid_field': 'query_id'
+                },
+            }
+            cfg['train']['neg_samples'] = 4
+            cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30
+            cfg.train.max_epochs = 1
+            cfg.train.train_batch_size = 4
+            cfg.train.lr_scheduler = {
+                'type': 'LinearLR',
+                'start_factor': 1.0,
+                'end_factor': 0.0,
+                'options': {
+                    'by_epoch': False
+                }
+            }
+            cfg.train.hooks = [{
+                'type': 'CheckpointHook',
+                'interval': 1
+            }, {
+                'type': 'TextLoggerHook',
+                'interval': 1
+            }, {
+                'type': 'IterTimerHook'
+            }, {
+                'type': 'EvaluationHook',
+                'by_epoch': False,
+                'interval': 3000
+            }]
+            return cfg
+
+        # load dataset
+        ds = MsDataset.load('passage-ranking-demo', 'zyznull')
+        train_ds = ds['train'].to_hf_dataset()
+        dev_ds = ds['train'].to_hf_dataset()
+        self.finetune(
+            model_id='damo/nlp_corom_passage-ranking_english-base',
+            train_dataset=train_ds,
+            eval_dataset=dev_ds,
+            cfg_modify_fn=cfg_modify_fn)
+
+        output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)
+        self.pipeline_passage_ranking(output_dir)
+
+    def pipeline_passage_ranking(self, model_dir):
+        model = Model.from_pretrained(model_dir)
+        pipeline_ins = pipeline(task=Tasks.passage_ranking, model=model)
+        print(pipeline_ins(input=self.inputs))
+
+
+if __name__ == '__main__':
+    unittest.main()