Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9856179master
| @@ -193,6 +193,8 @@ class Pipelines(object): | |||
| plug_generation = 'plug-generation' | |||
| faq_question_answering = 'faq-question-answering' | |||
| conversational_text_to_sql = 'conversational-text-to-sql' | |||
| sentence_embedding = 'sentence-embedding' | |||
| passage_ranking = 'passage-ranking' | |||
| relation_extraction = 'relation-extraction' | |||
| document_segmentation = 'document-segmentation' | |||
| @@ -245,6 +247,7 @@ class Trainers(object): | |||
| dialog_intent_trainer = 'dialog-intent-trainer' | |||
| nlp_base_trainer = 'nlp-base-trainer' | |||
| nlp_veco_trainer = 'nlp-veco-trainer' | |||
| nlp_passage_ranking_trainer = 'nlp-passage-ranking-trainer' | |||
| # audio trainers | |||
| speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | |||
| @@ -272,6 +275,7 @@ class Preprocessors(object): | |||
| # nlp preprocessor | |||
| sen_sim_tokenizer = 'sen-sim-tokenizer' | |||
| cross_encoder_tokenizer = 'cross-encoder-tokenizer' | |||
| bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer' | |||
| text_gen_tokenizer = 'text-gen-tokenizer' | |||
| token_cls_tokenizer = 'token-cls-tokenizer' | |||
| @@ -284,6 +288,8 @@ class Preprocessors(object): | |||
| sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer' | |||
| zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' | |||
| text_error_correction = 'text-error-correction' | |||
| sentence_embedding = 'sentence-embedding' | |||
| passage_ranking = 'passage-ranking' | |||
| sequence_labeling_tokenizer = 'sequence-labeling-tokenizer' | |||
| word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor' | |||
| fill_mask = 'fill-mask' | |||
| @@ -29,6 +29,8 @@ if TYPE_CHECKING: | |||
| SingleBackboneTaskModelBase, | |||
| TokenClassificationModel) | |||
| from .token_classification import SbertForTokenClassification | |||
| from .sentence_embedding import SentenceEmbedding | |||
| from .passage_ranking import PassageRanking | |||
| else: | |||
| _import_structure = { | |||
| @@ -62,6 +64,8 @@ else: | |||
| 'SingleBackboneTaskModelBase', 'TokenClassificationModel' | |||
| ], | |||
| 'token_classification': ['SbertForTokenClassification'], | |||
| 'sentence_embedding': ['SentenceEmbedding'], | |||
| 'passage_ranking': ['PassageRanking'], | |||
| } | |||
| import sys | |||
| @@ -0,0 +1,78 @@ | |||
| from typing import Any, Dict | |||
| import numpy as np | |||
| import torch | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models import TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.models.nlp import SbertForSequenceClassification | |||
| from modelscope.models.nlp.structbert import SbertPreTrainedModel | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.utils.constant import Tasks | |||
| __all__ = ['PassageRanking'] | |||
| @MODELS.register_module(Tasks.passage_ranking, module_name=Models.bert) | |||
| class PassageRanking(SbertForSequenceClassification, SbertPreTrainedModel): | |||
| base_model_prefix: str = 'bert' | |||
| supports_gradient_checkpointing = True | |||
| _keys_to_ignore_on_load_missing = [r'position_ids'] | |||
| def __init__(self, config, model_dir, *args, **kwargs): | |||
| if hasattr(config, 'base_model_prefix'): | |||
| PassageRanking.base_model_prefix = config.base_model_prefix | |||
| super().__init__(config, model_dir) | |||
| self.train_batch_size = kwargs.get('train_batch_size', 4) | |||
| self.register_buffer( | |||
| 'target_label', | |||
| torch.zeros(self.train_batch_size, dtype=torch.long)) | |||
| def build_base_model(self): | |||
| from .structbert import SbertModel | |||
| return SbertModel(self.config, add_pooling_layer=True) | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: | |||
| outputs = self.base_model.forward(**input) | |||
| # backbone model should return pooled_output as its second output | |||
| pooled_output = outputs[1] | |||
| pooled_output = self.dropout(pooled_output) | |||
| logits = self.classifier(pooled_output) | |||
| if self.base_model.training: | |||
| scores = logits.view(self.train_batch_size, -1) | |||
| loss_fct = torch.nn.CrossEntropyLoss() | |||
| loss = loss_fct(scores, self.target_label) | |||
| return {OutputKeys.LOGITS: logits, OutputKeys.LOSS: loss} | |||
| return {OutputKeys.LOGITS: logits} | |||
| def sigmoid(self, logits): | |||
| return np.exp(logits) / (1 + np.exp(logits)) | |||
| def postprocess(self, inputs: Dict[str, np.ndarray], | |||
| **kwargs) -> Dict[str, np.ndarray]: | |||
| logits = inputs['logits'].squeeze(-1).detach().cpu().numpy() | |||
| logits = self.sigmoid(logits).tolist() | |||
| result = {OutputKeys.SCORES: logits} | |||
| return result | |||
| @classmethod | |||
| def _instantiate(cls, **kwargs): | |||
| """Instantiate the model. | |||
| @param kwargs: Input args. | |||
| model_dir: The model dir used to load the checkpoint and the label information. | |||
| num_labels: An optional arg to tell the model how many classes to initialize. | |||
| Method will call utils.parse_label_mapping if num_labels not supplied. | |||
| If num_labels is not found, the model will use the default setting (1 classes). | |||
| @return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained | |||
| """ | |||
| num_labels = kwargs.get('num_labels', 1) | |||
| model_args = {} if num_labels is None else {'num_labels': num_labels} | |||
| return super(SbertPreTrainedModel, PassageRanking).from_pretrained( | |||
| pretrained_model_name_or_path=kwargs.get('model_dir'), | |||
| model_dir=kwargs.get('model_dir'), | |||
| **model_args) | |||
| @@ -0,0 +1,74 @@ | |||
| import os | |||
| from typing import Any, Dict | |||
| import json | |||
| import numpy as np | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models import TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.models.nlp.structbert import SbertPreTrainedModel | |||
| from modelscope.utils.constant import Tasks | |||
| __all__ = ['SentenceEmbedding'] | |||
| @MODELS.register_module(Tasks.sentence_embedding, module_name=Models.bert) | |||
| class SentenceEmbedding(TorchModel, SbertPreTrainedModel): | |||
| base_model_prefix: str = 'bert' | |||
| supports_gradient_checkpointing = True | |||
| _keys_to_ignore_on_load_missing = [r'position_ids'] | |||
| def __init__(self, config, model_dir): | |||
| super().__init__(model_dir) | |||
| self.config = config | |||
| setattr(self, self.base_model_prefix, self.build_base_model()) | |||
| def build_base_model(self): | |||
| from .structbert import SbertModel | |||
| return SbertModel(self.config, add_pooling_layer=False) | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: | |||
| """return the result by the model | |||
| Args: | |||
| input (Dict[str, Any]): the preprocessed data | |||
| Returns: | |||
| Dict[str, np.ndarray]: results | |||
| Example: | |||
| { | |||
| 'predictions': array([1]), # lable 0-negative 1-positive | |||
| 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), | |||
| 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||
| } | |||
| """ | |||
| return self.base_model(**input) | |||
| def postprocess(self, inputs: Dict[str, np.ndarray], | |||
| **kwargs) -> Dict[str, np.ndarray]: | |||
| embs = inputs['last_hidden_state'][:, 0].cpu().numpy() | |||
| num_sent = embs.shape[0] | |||
| if num_sent >= 2: | |||
| scores = np.dot(embs[0:1, ], np.transpose(embs[1:, ], | |||
| (1, 0))).tolist()[0] | |||
| else: | |||
| scores = [] | |||
| result = {'text_embedding': embs, 'scores': scores} | |||
| return result | |||
| @classmethod | |||
| def _instantiate(cls, **kwargs): | |||
| """Instantiate the model. | |||
| @param kwargs: Input args. | |||
| model_dir: The model dir used to load the checkpoint and the label information. | |||
| @return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained | |||
| """ | |||
| model_args = {} | |||
| return super(SbertPreTrainedModel, SentenceEmbedding).from_pretrained( | |||
| pretrained_model_name_or_path=kwargs.get('model_dir'), | |||
| model_dir=kwargs.get('model_dir'), | |||
| **model_args) | |||
| @@ -11,12 +11,14 @@ if TYPE_CHECKING: | |||
| from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset | |||
| from .movie_scene_segmentation import MovieSceneSegmentationDataset | |||
| from .video_summarization_dataset import VideoSummarizationDataset | |||
| from .passage_ranking_dataset import PassageRankingDataset | |||
| else: | |||
| _import_structure = { | |||
| 'base': ['TaskDataset'], | |||
| 'builder': ['TASK_DATASETS', 'build_task_dataset'], | |||
| 'torch_base_dataset': ['TorchTaskDataset'], | |||
| 'passage_ranking_dataset': ['PassageRankingDataset'], | |||
| 'veco_dataset': ['VecoDataset'], | |||
| 'image_instance_segmentation_coco_dataset': | |||
| ['ImageInstanceSegmentationCocoDataset'], | |||
| @@ -0,0 +1,151 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import random | |||
| from dataclasses import dataclass | |||
| from typing import Any, Dict, List, Tuple, Union | |||
| import torch | |||
| from datasets import Dataset, IterableDataset, concatenate_datasets | |||
| from torch.utils.data import ConcatDataset | |||
| from transformers import DataCollatorWithPadding | |||
| from modelscope.metainfo import Models | |||
| from modelscope.utils.constant import ModeKeys, Tasks | |||
| from .base import TaskDataset | |||
| from .builder import TASK_DATASETS | |||
| from .torch_base_dataset import TorchTaskDataset | |||
| @TASK_DATASETS.register_module( | |||
| group_key=Tasks.passage_ranking, module_name=Models.bert) | |||
| class PassageRankingDataset(TorchTaskDataset): | |||
| def __init__(self, | |||
| datasets: Union[Any, List[Any]], | |||
| mode, | |||
| preprocessor=None, | |||
| *args, | |||
| **kwargs): | |||
| self.seed = kwargs.get('seed', 42) | |||
| self.permutation = None | |||
| self.datasets = None | |||
| self.dataset_config = kwargs | |||
| self.query_sequence = self.dataset_config.get('query_sequence', | |||
| 'query') | |||
| self.pos_sequence = self.dataset_config.get('pos_sequence', | |||
| 'positive_passages') | |||
| self.neg_sequence = self.dataset_config.get('neg_sequence', | |||
| 'negative_passages') | |||
| self.passage_text_fileds = self.dataset_config.get( | |||
| 'passage_text_fileds', ['title', 'text']) | |||
| self.qid_field = self.dataset_config.get('qid_field', 'query_id') | |||
| if mode == ModeKeys.TRAIN: | |||
| train_config = kwargs.get('train', {}) | |||
| self.neg_samples = train_config.get('neg_samples', 4) | |||
| super().__init__(datasets, mode, preprocessor, **kwargs) | |||
| def __getitem__(self, index) -> Any: | |||
| if self.mode == ModeKeys.TRAIN: | |||
| return self.__get_train_item__(index) | |||
| else: | |||
| return self.__get_test_item__(index) | |||
| def __get_test_item__(self, index): | |||
| group = self._inner_dataset[index] | |||
| labels = [] | |||
| qry = group[self.query_sequence] | |||
| pos_sequences = group[self.pos_sequence] | |||
| pos_sequences = [ | |||
| ' '.join([ele[key] for key in self.passage_text_fileds]) | |||
| for ele in pos_sequences | |||
| ] | |||
| labels.extend([1] * len(pos_sequences)) | |||
| neg_sequences = group[self.neg_sequence] | |||
| neg_sequences = [ | |||
| ' '.join([ele[key] for key in self.passage_text_fileds]) | |||
| for ele in neg_sequences | |||
| ] | |||
| labels.extend([0] * len(neg_sequences)) | |||
| qid = group[self.qid_field] | |||
| examples = pos_sequences + neg_sequences | |||
| sample = { | |||
| 'qid': torch.LongTensor([int(qid)] * len(labels)), | |||
| self.preprocessor.first_sequence: qry, | |||
| self.preprocessor.second_sequence: examples, | |||
| 'labels': torch.LongTensor(labels) | |||
| } | |||
| return self.prepare_sample(sample) | |||
| def __get_train_item__(self, index): | |||
| group = self._inner_dataset[index] | |||
| qry = group[self.query_sequence] | |||
| pos_sequences = group[self.pos_sequence] | |||
| pos_sequences = [ | |||
| ' '.join([ele[key] for key in self.passage_text_fileds]) | |||
| for ele in pos_sequences | |||
| ] | |||
| neg_sequences = group[self.neg_sequence] | |||
| neg_sequences = [ | |||
| ' '.join([ele[key] for key in self.passage_text_fileds]) | |||
| for ele in neg_sequences | |||
| ] | |||
| pos_psg = random.choice(pos_sequences) | |||
| if len(neg_sequences) < self.neg_samples: | |||
| negs = random.choices(neg_sequences, k=self.neg_samples) | |||
| else: | |||
| negs = random.sample(neg_sequences, k=self.neg_samples) | |||
| examples = [pos_psg] + negs | |||
| sample = { | |||
| self.preprocessor.first_sequence: qry, | |||
| self.preprocessor.second_sequence: examples, | |||
| } | |||
| return self.prepare_sample(sample) | |||
| def __len__(self): | |||
| return len(self._inner_dataset) | |||
| def prepare_dataset(self, datasets: Union[Any, List[Any]]) -> Any: | |||
| """Prepare a dataset. | |||
| User can process the input datasets in a whole dataset perspective. | |||
| This method gives a default implementation of datasets merging, user can override this | |||
| method to write custom logics. | |||
| Args: | |||
| datasets: The original dataset(s) | |||
| Returns: A single dataset, which may be created after merging. | |||
| """ | |||
| if isinstance(datasets, List): | |||
| if len(datasets) == 1: | |||
| return datasets[0] | |||
| elif len(datasets) > 1: | |||
| return ConcatDataset(datasets) | |||
| else: | |||
| return datasets | |||
| def prepare_sample(self, data): | |||
| """Preprocess the data fetched from the inner_dataset. | |||
| If the preprocessor is None, the original data will be returned, else the preprocessor will be called. | |||
| User can override this method to implement custom logics. | |||
| Args: | |||
| data: The data fetched from the dataset. | |||
| Returns: The processed data. | |||
| """ | |||
| return self.preprocessor( | |||
| data) if self.preprocessor is not None else data | |||
| @@ -387,19 +387,14 @@ TASK_OUTPUTS = { | |||
| # "output": "我想吃苹果" | |||
| # } | |||
| Tasks.text_error_correction: [OutputKeys.OUTPUT], | |||
| Tasks.sentence_embedding: [OutputKeys.TEXT_EMBEDDING, OutputKeys.SCORES], | |||
| Tasks.passage_ranking: [OutputKeys.SCORES], | |||
| # text generation result for single sample | |||
| # { | |||
| # "text": "this is the text generated by a model." | |||
| # } | |||
| Tasks.text_generation: [OutputKeys.TEXT], | |||
| # text feature extraction for single sample | |||
| # { | |||
| # "text_embedding": np.array with shape [1, D] | |||
| # } | |||
| Tasks.sentence_embedding: [OutputKeys.TEXT_EMBEDDING], | |||
| # fill mask result for single sample | |||
| # { | |||
| # "text": "this is the text which masks filled by model." | |||
| @@ -17,6 +17,11 @@ PIPELINES = Registry('pipelines') | |||
| DEFAULT_MODEL_FOR_PIPELINE = { | |||
| # TaskName: (pipeline_module_name, model_repo) | |||
| Tasks.sentence_embedding: | |||
| (Pipelines.sentence_embedding, | |||
| 'damo/nlp_corom_sentence-embedding_english-base'), | |||
| Tasks.passage_ranking: (Pipelines.passage_ranking, | |||
| 'damo/nlp_corom_passage-ranking_english-base'), | |||
| Tasks.word_segmentation: | |||
| (Pipelines.word_segmentation, | |||
| 'damo/nlp_structbert_word-segmentation_chinese-base'), | |||
| @@ -25,7 +25,8 @@ if TYPE_CHECKING: | |||
| from .translation_pipeline import TranslationPipeline | |||
| from .word_segmentation_pipeline import WordSegmentationPipeline | |||
| from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline | |||
| from .passage_ranking_pipeline import PassageRankingPipeline | |||
| from .sentence_embedding_pipeline import SentenceEmbeddingPipeline | |||
| else: | |||
| _import_structure = { | |||
| 'conversational_text_to_sql_pipeline': | |||
| @@ -55,6 +56,8 @@ else: | |||
| 'word_segmentation_pipeline': ['WordSegmentationPipeline'], | |||
| 'zero_shot_classification_pipeline': | |||
| ['ZeroShotClassificationPipeline'], | |||
| 'passage_ranking_pipeline': ['PassageRankingPipeline'], | |||
| 'sentence_embedding_pipeline': ['SentenceEmbeddingPipeline'] | |||
| } | |||
| import sys | |||
| @@ -0,0 +1,58 @@ | |||
| from typing import Any, Dict, Optional, Union | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models import Model | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import PassageRankingPreprocessor, Preprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| __all__ = ['PassageRankingPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.passage_ranking, module_name=Pipelines.passage_ranking) | |||
| class PassageRankingPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[Model, str], | |||
| preprocessor: Optional[Preprocessor] = None, | |||
| **kwargs): | |||
| """Use `model` and `preprocessor` to create a nlp word segment pipeline for prediction. | |||
| Args: | |||
| model (str or Model): Supply either a local model dir which supported the WS task, | |||
| or a model id from the model hub, or a torch model instance. | |||
| preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for | |||
| the model if supplied. | |||
| sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value. | |||
| """ | |||
| model = model if isinstance(model, | |||
| Model) else Model.from_pretrained(model) | |||
| if preprocessor is None: | |||
| preprocessor = PassageRankingPreprocessor( | |||
| model.model_dir if isinstance(model, Model) else model, | |||
| sequence_length=kwargs.pop('sequence_length', 128)) | |||
| model.eval() | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return {**self.model(inputs, **forward_params)} | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Any]): _description_ | |||
| Returns: | |||
| Dict[str, Any]: the predicted text representation | |||
| """ | |||
| pred_list = inputs[OutputKeys.SCORES] | |||
| return {OutputKeys.SCORES: pred_list} | |||
| @@ -0,0 +1,60 @@ | |||
| from typing import Any, Dict, Optional, Union | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models import Model | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import (Preprocessor, | |||
| SentenceEmbeddingPreprocessor) | |||
| from modelscope.utils.constant import Tasks | |||
| __all__ = ['SentenceEmbeddingPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.sentence_embedding, module_name=Pipelines.sentence_embedding) | |||
| class SentenceEmbeddingPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[Model, str], | |||
| preprocessor: Optional[Preprocessor] = None, | |||
| first_sequence='first_sequence', | |||
| **kwargs): | |||
| """Use `model` and `preprocessor` to create a nlp text dual encoder then generates the text representation. | |||
| Args: | |||
| model (str or Model): Supply either a local model dir which supported the WS task, | |||
| or a model id from the model hub, or a torch model instance. | |||
| preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for | |||
| the model if supplied. | |||
| sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value. | |||
| """ | |||
| model = model if isinstance(model, | |||
| Model) else Model.from_pretrained(model) | |||
| if preprocessor is None: | |||
| preprocessor = SentenceEmbeddingPreprocessor( | |||
| model.model_dir if isinstance(model, Model) else model, | |||
| first_sequence=first_sequence, | |||
| sequence_length=kwargs.pop('sequence_length', 128)) | |||
| model.eval() | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return {**self.model(inputs, **forward_params)} | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Any]): _description_ | |||
| Returns: | |||
| Dict[str, Any]: the predicted text representation | |||
| """ | |||
| embs = inputs[OutputKeys.TEXT_EMBEDDING] | |||
| scores = inputs[OutputKeys.SCORES] | |||
| return {OutputKeys.TEXT_EMBEDDING: embs, OutputKeys.SCORES: scores} | |||
| @@ -23,7 +23,8 @@ if TYPE_CHECKING: | |||
| ZeroShotClassificationPreprocessor, NERPreprocessor, | |||
| TextErrorCorrectionPreprocessor, FaqQuestionAnsweringPreprocessor, | |||
| SequenceLabelingPreprocessor, RelationExtractionPreprocessor, | |||
| DocumentSegmentationPreprocessor, FillMaskPoNetPreprocessor) | |||
| DocumentSegmentationPreprocessor, FillMaskPoNetPreprocessor, | |||
| PassageRankingPreprocessor) | |||
| from .space import (DialogIntentPredictionPreprocessor, | |||
| DialogModelingPreprocessor, | |||
| DialogStateTrackingPreprocessor) | |||
| @@ -50,6 +51,7 @@ else: | |||
| 'SingleSentenceClassificationPreprocessor', | |||
| 'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor', | |||
| 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', | |||
| 'SentenceEmbeddingPreprocessor', 'PassageRankingPreprocessor', | |||
| 'TextErrorCorrectionPreprocessor', | |||
| 'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor', | |||
| 'RelationExtractionPreprocessor', | |||
| @@ -29,6 +29,7 @@ __all__ = [ | |||
| 'PairSentenceClassificationPreprocessor', | |||
| 'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor', | |||
| 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', | |||
| 'SentenceEmbeddingPreprocessor', 'PassageRankingPreprocessor', | |||
| 'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor', | |||
| 'SequenceLabelingPreprocessor', 'RelationExtractionPreprocessor', | |||
| 'DocumentSegmentationPreprocessor', 'FillMaskPoNetPreprocessor' | |||
| @@ -100,6 +101,7 @@ class SequenceClassificationPreprocessor(Preprocessor): | |||
| text_a = new_data[self.first_sequence] | |||
| text_b = new_data.get(self.second_sequence, None) | |||
| feature = self.tokenizer( | |||
| text_a, | |||
| text_b, | |||
| @@ -111,7 +113,6 @@ class SequenceClassificationPreprocessor(Preprocessor): | |||
| rst['input_ids'].append(feature['input_ids']) | |||
| rst['attention_mask'].append(feature['attention_mask']) | |||
| rst['token_type_ids'].append(feature['token_type_ids']) | |||
| return rst | |||
| @@ -268,6 +269,62 @@ class NLPTokenizerPreprocessorBase(Preprocessor): | |||
| output[OutputKeys.LABELS] = labels | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.passage_ranking) | |||
| class PassageRankingPreprocessor(NLPTokenizerPreprocessorBase): | |||
| """The tokenizer preprocessor used in passage ranking model. | |||
| """ | |||
| def __init__(self, | |||
| model_dir: str, | |||
| mode=ModeKeys.INFERENCE, | |||
| *args, | |||
| **kwargs): | |||
| """preprocess the data | |||
| Args: | |||
| model_dir (str): model path | |||
| """ | |||
| super().__init__(model_dir, pair=True, mode=mode, *args, **kwargs) | |||
| self.model_dir: str = model_dir | |||
| self.first_sequence: str = kwargs.pop('first_sequence', | |||
| 'source_sentence') | |||
| self.second_sequence = kwargs.pop('second_sequence', | |||
| 'sentences_to_compare') | |||
| self.sequence_length = kwargs.pop('sequence_length', 128) | |||
| self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir) | |||
| @type_assert(object, (str, tuple, Dict)) | |||
| def __call__(self, data: Union[tuple, Dict]) -> Dict[str, Any]: | |||
| if isinstance(data, tuple): | |||
| sentence1, sentence2 = data | |||
| elif isinstance(data, dict): | |||
| sentence1 = data.get(self.first_sequence) | |||
| sentence2 = data.get(self.second_sequence) | |||
| if isinstance(sentence2, str): | |||
| sentence2 = [sentence2] | |||
| if isinstance(sentence1, str): | |||
| sentence1 = [sentence1] | |||
| sentence1 = sentence1 * len(sentence2) | |||
| max_seq_length = self.sequence_length | |||
| feature = self.tokenizer( | |||
| sentence1, | |||
| sentence2, | |||
| padding='max_length', | |||
| truncation=True, | |||
| max_length=max_seq_length, | |||
| return_tensors='pt') | |||
| if 'labels' in data: | |||
| labels = data['labels'] | |||
| feature['labels'] = labels | |||
| if 'qid' in data: | |||
| qid = data['qid'] | |||
| feature['qid'] = qid | |||
| return feature | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.nli_tokenizer) | |||
| @PREPROCESSORS.register_module( | |||
| @@ -298,6 +355,51 @@ class SingleSentenceClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| super().__init__(model_dir, pair=False, mode=mode, **kwargs) | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.sentence_embedding) | |||
| class SentenceEmbeddingPreprocessor(NLPTokenizerPreprocessorBase): | |||
| """The tokenizer preprocessor used in sentence embedding. | |||
| """ | |||
| def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): | |||
| kwargs['truncation'] = kwargs.get('truncation', True) | |||
| kwargs['padding'] = kwargs.get( | |||
| 'padding', False if mode == ModeKeys.INFERENCE else 'max_length') | |||
| kwargs['max_length'] = kwargs.pop('sequence_length', 128) | |||
| super().__init__(model_dir, pair=False, mode=mode, **kwargs) | |||
| def __call__(self, data: Union[str, Dict]) -> Dict[str, Any]: | |||
| """process the raw input data | |||
| Args: | |||
| data Dict: | |||
| keys: "source_sentence" && "sentences_to_compare" | |||
| values: list of sentences | |||
| Example: | |||
| {"source_sentence": ["how long it take to get a master's degree"], | |||
| "sentences_to_compare": ["On average, students take about 18 to 24 months | |||
| to complete a master's degree.", | |||
| "On the other hand, some students prefer to go at a slower pace | |||
| and choose to take several years to complete their studies.", | |||
| "It can take anywhere from two semesters"]} | |||
| Returns: | |||
| Dict[str, Any]: the preprocessed data | |||
| """ | |||
| source_sentence = data['source_sentence'] | |||
| compare_sentences = data['sentences_to_compare'] | |||
| sentences = [] | |||
| sentences.append(source_sentence[0]) | |||
| for sent in compare_sentences: | |||
| sentences.append(sent) | |||
| tokenized_inputs = self.tokenizer( | |||
| sentences, | |||
| return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None, | |||
| padding=True, | |||
| truncation=True) | |||
| return tokenized_inputs | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.zero_shot_cls_tokenizer) | |||
| class ZeroShotClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| @@ -11,7 +11,7 @@ if TYPE_CHECKING: | |||
| ImagePortraitEnhancementTrainer, | |||
| MovieSceneSegmentationTrainer) | |||
| from .multi_modal import CLIPTrainer | |||
| from .nlp import SequenceClassificationTrainer | |||
| from .nlp import SequenceClassificationTrainer, PassageRankingTrainer | |||
| from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer | |||
| from .trainer import EpochBasedTrainer | |||
| @@ -25,7 +25,7 @@ else: | |||
| 'ImagePortraitEnhancementTrainer', 'MovieSceneSegmentationTrainer' | |||
| ], | |||
| 'multi_modal': ['CLIPTrainer'], | |||
| 'nlp': ['SequenceClassificationTrainer'], | |||
| 'nlp': ['SequenceClassificationTrainer', 'PassageRankingTrainer'], | |||
| 'nlp_trainer': ['NlpEpochBasedTrainer', 'VecoTrainer'], | |||
| 'trainer': ['EpochBasedTrainer'] | |||
| } | |||
| @@ -6,10 +6,12 @@ from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .sequence_classification_trainer import SequenceClassificationTrainer | |||
| from .csanmt_translation_trainer import CsanmtTranslationTrainer | |||
| from .passage_ranking_trainer import PassageRankingTranier | |||
| else: | |||
| _import_structure = { | |||
| 'sequence_classification_trainer': ['SequenceClassificationTrainer'], | |||
| 'csanmt_translation_trainer': ['CsanmtTranslationTrainer'], | |||
| 'passage_ranking_trainer': ['PassageRankingTrainer'] | |||
| } | |||
| import sys | |||
| @@ -0,0 +1,197 @@ | |||
| import time | |||
| from dataclasses import dataclass | |||
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union | |||
| import numpy as np | |||
| import torch | |||
| from torch import nn | |||
| from torch.utils.data import DataLoader, Dataset | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.models.base import Model, TorchModel | |||
| from modelscope.msdatasets.ms_dataset import MsDataset | |||
| from modelscope.preprocessors.base import Preprocessor | |||
| from modelscope.trainers.base import BaseTrainer | |||
| from modelscope.trainers.builder import TRAINERS | |||
| from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer | |||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @dataclass | |||
| class GroupCollator(): | |||
| """ | |||
| Wrapper that does conversion from List[Tuple[encode_qry, encode_psg]] to List[qry], List[psg] | |||
| and pass batch separately to the actual collator. | |||
| Abstract out data detail for the model. | |||
| """ | |||
| def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: | |||
| if isinstance(features[0], list): | |||
| features = sum(features, []) | |||
| keys = features[0].keys() | |||
| batch = {k: list() for k in keys} | |||
| for ele in features: | |||
| for k, v in ele.items(): | |||
| batch[k].append(v) | |||
| batch = {k: torch.cat(v, dim=0) for k, v in batch.items()} | |||
| return batch | |||
| @TRAINERS.register_module(module_name=Trainers.nlp_passage_ranking_trainer) | |||
| class PassageRankingTrainer(NlpEpochBasedTrainer): | |||
| def __init__( | |||
| self, | |||
| model: Optional[Union[TorchModel, nn.Module, str]] = None, | |||
| cfg_file: Optional[str] = None, | |||
| cfg_modify_fn: Optional[Callable] = None, | |||
| arg_parse_fn: Optional[Callable] = None, | |||
| data_collator: Optional[Callable] = None, | |||
| train_dataset: Optional[Union[MsDataset, Dataset]] = None, | |||
| eval_dataset: Optional[Union[MsDataset, Dataset]] = None, | |||
| preprocessor: Optional[Preprocessor] = None, | |||
| optimizers: Tuple[torch.optim.Optimizer, | |||
| torch.optim.lr_scheduler._LRScheduler] = (None, | |||
| None), | |||
| model_revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
| **kwargs): | |||
| if data_collator is None: | |||
| data_collator = GroupCollator() | |||
| super().__init__( | |||
| model=model, | |||
| cfg_file=cfg_file, | |||
| cfg_modify_fn=cfg_modify_fn, | |||
| arg_parse_fn=arg_parse_fn, | |||
| data_collator=data_collator, | |||
| preprocessor=preprocessor, | |||
| optimizers=optimizers, | |||
| train_dataset=train_dataset, | |||
| eval_dataset=eval_dataset, | |||
| model_revision=model_revision, | |||
| **kwargs) | |||
| def compute_mrr(self, result, k=10): | |||
| mrr = 0 | |||
| for res in result.values(): | |||
| sorted_res = sorted(res, key=lambda x: x[0], reverse=True) | |||
| ar = 0 | |||
| for index, ele in enumerate(sorted_res[:k]): | |||
| if str(ele[1]) == '1': | |||
| ar = 1.0 / (index + 1) | |||
| break | |||
| mrr += ar | |||
| return mrr / len(result) | |||
| def compute_ndcg(self, result, k=10): | |||
| ndcg = 0 | |||
| from sklearn import ndcg_score | |||
| for res in result.values(): | |||
| sorted_res = sorted(res, key=lambda x: [0], reverse=True) | |||
| labels = np.array([[ele[1] for ele in sorted_res]]) | |||
| scores = np.array([[ele[0] for ele in sorted_res]]) | |||
| ndcg += float(ndcg_score(labels, scores, k=k)) | |||
| ndcg = ndcg / len(result) | |||
| return ndcg | |||
| def evaluate(self, | |||
| checkpoint_path: Optional[str] = None, | |||
| *args, | |||
| **kwargs) -> Dict[str, float]: | |||
| """evaluate a dataset | |||
| evaluate a dataset via a specific model from the `checkpoint_path` path, if the `checkpoint_path` | |||
| does not exist, read from the config file. | |||
| Args: | |||
| checkpoint_path (Optional[str], optional): the model path. Defaults to None. | |||
| Returns: | |||
| Dict[str, float]: the results about the evaluation | |||
| Example: | |||
| {"accuracy": 0.5091743119266054, "f1": 0.673780487804878} | |||
| """ | |||
| from modelscope.models.nlp import PassageRanking | |||
| # get the raw online dataset | |||
| self.eval_dataloader = self._build_dataloader_with_dataset( | |||
| self.eval_dataset, | |||
| **self.cfg.evaluation.get('dataloader', {}), | |||
| collate_fn=self.eval_data_collator) | |||
| # generate a standard dataloader | |||
| # generate a model | |||
| if checkpoint_path is not None: | |||
| model = PassageRanking.from_pretrained(checkpoint_path) | |||
| else: | |||
| model = self.model | |||
| # copy from easynlp (start) | |||
| model.eval() | |||
| total_samples = 0 | |||
| logits_list = list() | |||
| label_list = list() | |||
| qid_list = list() | |||
| total_spent_time = 0.0 | |||
| device = 'cuda:0' if torch.cuda.is_available() else 'cpu' | |||
| model.to(device) | |||
| for _step, batch in enumerate(self.eval_dataloader): | |||
| try: | |||
| batch = { | |||
| key: | |||
| val.to(device) if isinstance(val, torch.Tensor) else val | |||
| for key, val in batch.items() | |||
| } | |||
| except RuntimeError: | |||
| batch = {key: val for key, val in batch.items()} | |||
| infer_start_time = time.time() | |||
| with torch.no_grad(): | |||
| label_ids = batch.pop('labels').detach().cpu().numpy() | |||
| qids = batch.pop('qid').detach().cpu().numpy() | |||
| outputs = model(batch) | |||
| infer_end_time = time.time() | |||
| total_spent_time += infer_end_time - infer_start_time | |||
| total_samples += self.eval_dataloader.batch_size | |||
| assert 'scores' in outputs | |||
| logits = outputs['scores'] | |||
| label_list.extend(label_ids) | |||
| logits_list.extend(logits) | |||
| qid_list.extend(qids) | |||
| logger.info('Inference time = {:.2f}s, [{:.4f} ms / sample] '.format( | |||
| total_spent_time, total_spent_time * 1000 / total_samples)) | |||
| rank_result = {} | |||
| for qid, score, label in zip(qid_list, logits_list, label_list): | |||
| if qid not in rank_result: | |||
| rank_result[qid] = [] | |||
| rank_result[qid].append((score, label)) | |||
| for qid in rank_result: | |||
| rank_result[qid] = sorted(rank_result[qid], key=lambda x: x[0]) | |||
| eval_outputs = list() | |||
| for metric in self.metrics: | |||
| if metric.startswith('mrr'): | |||
| k = metric.split('@')[-1] | |||
| k = int(k) | |||
| mrr = self.compute_mrr(rank_result, k=k) | |||
| logger.info('{}: {}'.format(metric, mrr)) | |||
| eval_outputs.append((metric, mrr)) | |||
| elif metric.startswith('ndcg'): | |||
| k = metric.split('@')[-1] | |||
| k = int(k) | |||
| ndcg = self.compute_ndcg(rank_result, k=k) | |||
| logger.info('{}: {}'.format(metric, ndcg)) | |||
| eval_outputs.append(('ndcg', ndcg)) | |||
| else: | |||
| raise NotImplementedError('Metric %s not implemented' % metric) | |||
| return dict(eval_outputs) | |||
| @@ -345,12 +345,12 @@ class EpochBasedTrainer(BaseTrainer): | |||
| type=self.cfg.task, mode=mode, datasets=datasets) | |||
| return build_task_dataset(cfg, self.cfg.task) | |||
| else: | |||
| task_data_config.update( | |||
| dict( | |||
| mode=mode, | |||
| datasets=datasets, | |||
| preprocessor=preprocessor)) | |||
| return build_task_dataset(task_data_config, self.cfg.task) | |||
| # avoid add no str value datasets, preprocessors in cfg | |||
| task_data_build_config = ConfigDict( | |||
| mode=mode, datasets=datasets, preprocessor=preprocessor) | |||
| task_data_build_config.update(task_data_config) | |||
| return build_task_dataset(task_data_build_config, | |||
| self.cfg.task) | |||
| except Exception: | |||
| if isinstance(datasets, (List, Tuple)) or preprocessor is not None: | |||
| return TorchTaskDataset( | |||
| @@ -89,6 +89,8 @@ class NLPTasks(object): | |||
| sentiment_analysis = 'sentiment-analysis' | |||
| sentence_similarity = 'sentence-similarity' | |||
| text_classification = 'text-classification' | |||
| sentence_embedding = 'sentence-embedding' | |||
| passage_ranking = 'passage-ranking' | |||
| relation_extraction = 'relation-extraction' | |||
| zero_shot = 'zero-shot' | |||
| translation = 'translation' | |||
| @@ -0,0 +1,61 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import shutil | |||
| import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import PassageRanking | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.nlp import PassageRankingPipeline | |||
| from modelscope.preprocessors import PassageRankingPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class PassageRankingTest(unittest.TestCase): | |||
| model_id = 'damo/nlp_corom_passage-ranking_english-base' | |||
| inputs = { | |||
| 'source_sentence': ["how long it take to get a master's degree"], | |||
| 'sentences_to_compare': [ | |||
| "On average, students take about 18 to 24 months to complete a master's degree.", | |||
| 'On the other hand, some students prefer to go at a slower pace and choose to take ' | |||
| 'several years to complete their studies.', | |||
| 'It can take anywhere from two semesters' | |||
| ] | |||
| } | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_by_direct_model_download(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| tokenizer = PassageRankingPreprocessor(cache_path) | |||
| model = PassageRanking.from_pretrained(cache_path) | |||
| pipeline1 = PassageRankingPipeline(model, preprocessor=tokenizer) | |||
| pipeline2 = pipeline( | |||
| Tasks.passage_ranking, model=model, preprocessor=tokenizer) | |||
| print(f'sentence: {self.inputs}\n' | |||
| f'pipeline1:{pipeline1(input=self.inputs)}') | |||
| print() | |||
| print(f'pipeline2: {pipeline2(input=self.inputs)}') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| tokenizer = PassageRankingPreprocessor(model.model_dir) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.passage_ranking, model=model, preprocessor=tokenizer) | |||
| print(pipeline_ins(input=self.inputs)) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.passage_ranking, model=self.model_id) | |||
| print(pipeline_ins(input=self.inputs)) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipeline_ins = pipeline(task=Tasks.passage_ranking) | |||
| print(pipeline_ins(input=self.inputs)) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,82 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import shutil | |||
| import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import SentenceEmbedding | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.nlp import SentenceEmbeddingPipeline | |||
| from modelscope.preprocessors import SentenceEmbeddingPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class SentenceEmbeddingTest(unittest.TestCase): | |||
| model_id = 'damo/nlp_corom_sentence-embedding_english-base' | |||
| inputs = { | |||
| 'source_sentence': ["how long it take to get a master's degree"], | |||
| 'sentences_to_compare': [ | |||
| "On average, students take about 18 to 24 months to complete a master's degree.", | |||
| 'On the other hand, some students prefer to go at a slower pace and choose to take ', | |||
| 'several years to complete their studies.', | |||
| 'It can take anywhere from two semesters' | |||
| ] | |||
| } | |||
| inputs2 = { | |||
| 'source_sentence': ["how long it take to get a master's degree"], | |||
| 'sentences_to_compare': [ | |||
| "On average, students take about 18 to 24 months to complete a master's degree." | |||
| ] | |||
| } | |||
| inputs3 = { | |||
| 'source_sentence': ["how long it take to get a master's degree"], | |||
| 'sentences_to_compare': [] | |||
| } | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_by_direct_model_download(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| tokenizer = SentenceEmbeddingPreprocessor(cache_path) | |||
| model = SentenceEmbedding.from_pretrained(cache_path) | |||
| pipeline1 = SentenceEmbeddingPipeline(model, preprocessor=tokenizer) | |||
| pipeline2 = pipeline( | |||
| Tasks.sentence_embedding, model=model, preprocessor=tokenizer) | |||
| print(f'inputs: {self.inputs}\n' | |||
| f'pipeline1:{pipeline1(input=self.inputs)}') | |||
| print() | |||
| print(f'pipeline2: {pipeline2(input=self.inputs)}') | |||
| print() | |||
| print(f'inputs: {self.inputs2}\n' | |||
| f'pipeline1:{pipeline1(input=self.inputs2)}') | |||
| print() | |||
| print(f'pipeline2: {pipeline2(input=self.inputs2)}') | |||
| print(f'inputs: {self.inputs3}\n' | |||
| f'pipeline1:{pipeline1(input=self.inputs3)}') | |||
| print() | |||
| print(f'pipeline2: {pipeline2(input=self.inputs3)}') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| tokenizer = SentenceEmbeddingPreprocessor(model.model_dir) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.sentence_embedding, model=model, preprocessor=tokenizer) | |||
| print(pipeline_ins(input=self.inputs)) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.sentence_embedding, model=self.model_id) | |||
| print(pipeline_ins(input=self.inputs)) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipeline_ins = pipeline(task=Tasks.sentence_embedding) | |||
| print(pipeline_ins(input=self.inputs)) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,133 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import shutil | |||
| import tempfile | |||
| import unittest | |||
| from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union | |||
| import torch | |||
| from transformers.tokenization_utils_base import PreTrainedTokenizerBase | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.models import Model | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| inputs = { | |||
| 'source_sentence': ["how long it take to get a master's degree"], | |||
| 'sentences_to_compare': [ | |||
| "On average, students take about 18 to 24 months to complete a master's degree.", | |||
| 'On the other hand, some students prefer to go at a slower pace and choose to take ' | |||
| 'several years to complete their studies.', | |||
| 'It can take anywhere from two semesters' | |||
| ] | |||
| } | |||
| def setUp(self): | |||
| print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | |||
| self.tmp_dir = tempfile.TemporaryDirectory().name | |||
| if not os.path.exists(self.tmp_dir): | |||
| os.makedirs(self.tmp_dir) | |||
| def tearDown(self): | |||
| shutil.rmtree(self.tmp_dir) | |||
| super().tearDown() | |||
| def finetune(self, | |||
| model_id, | |||
| train_dataset, | |||
| eval_dataset, | |||
| name=Trainers.nlp_passage_ranking_trainer, | |||
| cfg_modify_fn=None, | |||
| **kwargs): | |||
| kwargs = dict( | |||
| model=model_id, | |||
| train_dataset=train_dataset, | |||
| eval_dataset=eval_dataset, | |||
| work_dir=self.tmp_dir, | |||
| cfg_modify_fn=cfg_modify_fn, | |||
| **kwargs) | |||
| os.environ['LOCAL_RANK'] = '0' | |||
| trainer = build_trainer(name=name, default_args=kwargs) | |||
| trainer.train() | |||
| results_files = os.listdir(self.tmp_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| def test_finetune_msmarco(self): | |||
| def cfg_modify_fn(cfg): | |||
| cfg.task = 'passage-ranking' | |||
| cfg['preprocessor'] = {'type': 'passage-ranking'} | |||
| cfg.train.optimizer.lr = 2e-5 | |||
| cfg['dataset'] = { | |||
| 'train': { | |||
| 'type': 'bert', | |||
| 'query_sequence': 'query', | |||
| 'pos_sequence': 'positive_passages', | |||
| 'neg_sequence': 'negative_passages', | |||
| 'passage_text_fileds': ['title', 'text'], | |||
| 'qid_field': 'query_id' | |||
| }, | |||
| 'val': { | |||
| 'type': 'bert', | |||
| 'query_sequence': 'query', | |||
| 'pos_sequence': 'positive_passages', | |||
| 'neg_sequence': 'negative_passages', | |||
| 'passage_text_fileds': ['title', 'text'], | |||
| 'qid_field': 'query_id' | |||
| }, | |||
| } | |||
| cfg['train']['neg_samples'] = 4 | |||
| cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30 | |||
| cfg.train.max_epochs = 1 | |||
| cfg.train.train_batch_size = 4 | |||
| cfg.train.lr_scheduler = { | |||
| 'type': 'LinearLR', | |||
| 'start_factor': 1.0, | |||
| 'end_factor': 0.0, | |||
| 'options': { | |||
| 'by_epoch': False | |||
| } | |||
| } | |||
| cfg.train.hooks = [{ | |||
| 'type': 'CheckpointHook', | |||
| 'interval': 1 | |||
| }, { | |||
| 'type': 'TextLoggerHook', | |||
| 'interval': 1 | |||
| }, { | |||
| 'type': 'IterTimerHook' | |||
| }, { | |||
| 'type': 'EvaluationHook', | |||
| 'by_epoch': False, | |||
| 'interval': 3000 | |||
| }] | |||
| return cfg | |||
| # load dataset | |||
| ds = MsDataset.load('passage-ranking-demo', 'zyznull') | |||
| train_ds = ds['train'].to_hf_dataset() | |||
| dev_ds = ds['train'].to_hf_dataset() | |||
| self.finetune( | |||
| model_id='damo/nlp_corom_passage-ranking_english-base', | |||
| train_dataset=train_ds, | |||
| eval_dataset=dev_ds, | |||
| cfg_modify_fn=cfg_modify_fn) | |||
| output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) | |||
| self.pipeline_passage_ranking(output_dir) | |||
| def pipeline_passage_ranking(self, model_dir): | |||
| model = Model.from_pretrained(model_dir) | |||
| pipeline_ins = pipeline(task=Tasks.passage_ranking, model=model) | |||
| print(pipeline_ins(input=self.inputs)) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||