Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10491951master
| @@ -236,7 +236,7 @@ class Pipelines(object): | |||
| conversational_text_to_sql = 'conversational-text-to-sql' | |||
| table_question_answering_pipeline = 'table-question-answering-pipeline' | |||
| sentence_embedding = 'sentence-embedding' | |||
| passage_ranking = 'passage-ranking' | |||
| text_ranking = 'text-ranking' | |||
| relation_extraction = 'relation-extraction' | |||
| document_segmentation = 'document-segmentation' | |||
| feature_extraction = 'feature-extraction' | |||
| @@ -297,7 +297,7 @@ class Trainers(object): | |||
| dialog_intent_trainer = 'dialog-intent-trainer' | |||
| nlp_base_trainer = 'nlp-base-trainer' | |||
| nlp_veco_trainer = 'nlp-veco-trainer' | |||
| nlp_passage_ranking_trainer = 'nlp-passage-ranking-trainer' | |||
| nlp_text_ranking_trainer = 'nlp-text-ranking-trainer' | |||
| # audio trainers | |||
| speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | |||
| @@ -343,7 +343,7 @@ class Preprocessors(object): | |||
| zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' | |||
| text_error_correction = 'text-error-correction' | |||
| sentence_embedding = 'sentence-embedding' | |||
| passage_ranking = 'passage-ranking' | |||
| text_ranking = 'text-ranking' | |||
| sequence_labeling_tokenizer = 'sequence-labeling-tokenizer' | |||
| word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor' | |||
| fill_mask = 'fill-mask' | |||
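For reviewers tracking the rename, the metainfo hunks above change the registry keys that every other lookup in this PR depends on. A quick sanity-check sketch (assumes an environment with this change applied):

```python
from modelscope.metainfo import Pipelines, Preprocessors, Trainers
from modelscope.utils.constant import Tasks

# The old 'passage-ranking' keys are gone; registry lookups now use these names.
assert Pipelines.text_ranking == 'text-ranking'
assert Preprocessors.text_ranking == 'text-ranking'
assert Trainers.nlp_text_ranking_trainer == 'nlp-text-ranking-trainer'
assert Tasks.text_ranking == 'text-ranking'
```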
| @@ -34,8 +34,9 @@ if TYPE_CHECKING: | |||
| TaskModelForTextGeneration) | |||
| from .token_classification import SbertForTokenClassification | |||
| from .sentence_embedding import SentenceEmbedding | |||
| from .passage_ranking import PassageRanking | |||
| from .text_ranking import TextRanking | |||
| from .T5 import T5ForConditionalGeneration | |||
| else: | |||
| _import_structure = { | |||
| 'backbones': ['SbertModel'], | |||
| @@ -75,7 +76,7 @@ else: | |||
| 'token_classification': ['SbertForTokenClassification'], | |||
| 'table_question_answering': ['TableQuestionAnswering'], | |||
| 'sentence_embedding': ['SentenceEmbedding'], | |||
| 'passage_ranking': ['PassageRanking'], | |||
| 'text_ranking': ['TextRanking'], | |||
| 'T5': ['T5ForConditionalGeneration'], | |||
| } | |||
| @@ -13,18 +13,18 @@ from modelscope.models.nlp.structbert import SbertPreTrainedModel | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.utils.constant import Tasks | |||
| __all__ = ['PassageRanking'] | |||
| __all__ = ['TextRanking'] | |||
| @MODELS.register_module(Tasks.passage_ranking, module_name=Models.bert) | |||
| class PassageRanking(SbertForSequenceClassification, SbertPreTrainedModel): | |||
| @MODELS.register_module(Tasks.text_ranking, module_name=Models.bert) | |||
| class TextRanking(SbertForSequenceClassification, SbertPreTrainedModel): | |||
| base_model_prefix: str = 'bert' | |||
| supports_gradient_checkpointing = True | |||
| _keys_to_ignore_on_load_missing = [r'position_ids'] | |||
| def __init__(self, config, model_dir, *args, **kwargs): | |||
| if hasattr(config, 'base_model_prefix'): | |||
| PassageRanking.base_model_prefix = config.base_model_prefix | |||
| TextRanking.base_model_prefix = config.base_model_prefix | |||
| super().__init__(config, model_dir) | |||
| self.train_batch_size = kwargs.get('train_batch_size', 4) | |||
| self.register_buffer( | |||
| @@ -74,7 +74,7 @@ class PassageRanking(SbertForSequenceClassification, SbertPreTrainedModel): | |||
| num_labels = kwargs.get('num_labels', 1) | |||
| model_args = {} if num_labels is None else {'num_labels': num_labels} | |||
| return super(SbertPreTrainedModel, PassageRanking).from_pretrained( | |||
| return super(SbertPreTrainedModel, TextRanking).from_pretrained( | |||
| pretrained_model_name_or_path=kwargs.get('model_dir'), | |||
| model_dir=kwargs.get('model_dir'), | |||
| **model_args) | |||
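A minimal sketch of loading the renamed model class directly from a downloaded checkpoint, mirroring the unit test further down; note the hub model id still carries the old passage-ranking suffix:

```python
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.nlp import TextRanking

# Download the checkpoint locally, then load it through the renamed class.
cache_path = snapshot_download('damo/nlp_corom_passage-ranking_english-base')
model = TextRanking.from_pretrained(cache_path)
```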
| @@ -12,14 +12,14 @@ if TYPE_CHECKING: | |||
| from .movie_scene_segmentation import MovieSceneSegmentationDataset | |||
| from .video_summarization_dataset import VideoSummarizationDataset | |||
| from .image_inpainting import ImageInpaintingDataset | |||
| from .passage_ranking_dataset import PassageRankingDataset | |||
| from .text_ranking_dataset import TextRankingDataset | |||
| else: | |||
| _import_structure = { | |||
| 'base': ['TaskDataset'], | |||
| 'builder': ['TASK_DATASETS', 'build_task_dataset'], | |||
| 'torch_base_dataset': ['TorchTaskDataset'], | |||
| 'passage_ranking_dataset': ['PassageRankingDataset'], | |||
| 'text_ranking_dataset': ['TextRankingDataset'], | |||
| 'veco_dataset': ['VecoDataset'], | |||
| 'image_instance_segmentation_coco_dataset': | |||
| ['ImageInstanceSegmentationCocoDataset'], | |||
| @@ -16,8 +16,8 @@ from .torch_base_dataset import TorchTaskDataset | |||
| @TASK_DATASETS.register_module( | |||
| group_key=Tasks.passage_ranking, module_name=Models.bert) | |||
| class PassageRankingDataset(TorchTaskDataset): | |||
| group_key=Tasks.text_ranking, module_name=Models.bert) | |||
| class TextRankingDataset(TorchTaskDataset): | |||
| def __init__(self, | |||
| datasets: Union[Any, List[Any]], | |||
| @@ -35,8 +35,8 @@ class PassageRankingDataset(TorchTaskDataset): | |||
| 'positive_passages') | |||
| self.neg_sequence = self.dataset_config.get('neg_sequence', | |||
| 'negative_passages') | |||
| self.passage_text_fileds = self.dataset_config.get( | |||
| 'passage_text_fileds', ['title', 'text']) | |||
| self.text_fileds = self.dataset_config.get('text_fileds', | |||
| ['title', 'text']) | |||
| self.qid_field = self.dataset_config.get('qid_field', 'query_id') | |||
| if mode == ModeKeys.TRAIN: | |||
| train_config = kwargs.get('train', {}) | |||
| @@ -58,14 +58,14 @@ class PassageRankingDataset(TorchTaskDataset): | |||
| pos_sequences = group[self.pos_sequence] | |||
| pos_sequences = [ | |||
| ' '.join([ele[key] for key in self.passage_text_fileds]) | |||
| ' '.join([ele[key] for key in self.text_fileds]) | |||
| for ele in pos_sequences | |||
| ] | |||
| labels.extend([1] * len(pos_sequences)) | |||
| neg_sequences = group[self.neg_sequence] | |||
| neg_sequences = [ | |||
| ' '.join([ele[key] for key in self.passage_text_fileds]) | |||
| ' '.join([ele[key] for key in self.text_fileds]) | |||
| for ele in neg_sequences | |||
| ] | |||
| @@ -88,13 +88,13 @@ class PassageRankingDataset(TorchTaskDataset): | |||
| pos_sequences = group[self.pos_sequence] | |||
| pos_sequences = [ | |||
| ' '.join([ele[key] for key in self.passage_text_fileds]) | |||
| ' '.join([ele[key] for key in self.text_fileds]) | |||
| for ele in pos_sequences | |||
| ] | |||
| neg_sequences = group[self.neg_sequence] | |||
| neg_sequences = [ | |||
| ' '.join([ele[key] for key in self.passage_text_fileds]) | |||
| ' '.join([ele[key] for key in self.text_fileds]) | |||
| for ele in neg_sequences | |||
| ] | |||
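The dataset hunks above also rename the config key that selects which fields are joined into each candidate text (passage_text_fileds -> text_fileds; spelling kept as in the source). A sketch of the relevant dataset_config entries, limited to the keys and defaults shown in these hunks:

```python
# Illustrative fragment of the task-dataset configuration read by TextRankingDataset.
dataset_config = {
    'pos_sequence': 'positive_passages',  # default shown above
    'neg_sequence': 'negative_passages',  # default shown above
    'text_fileds': ['title', 'text'],     # renamed key; spelling follows the code
    'qid_field': 'query_id',              # default shown above
}
```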
| @@ -506,7 +506,7 @@ TASK_OUTPUTS = { | |||
| # } | |||
| Tasks.text_error_correction: [OutputKeys.OUTPUT], | |||
| Tasks.sentence_embedding: [OutputKeys.TEXT_EMBEDDING, OutputKeys.SCORES], | |||
| Tasks.passage_ranking: [OutputKeys.SCORES], | |||
| Tasks.text_ranking: [OutputKeys.SCORES], | |||
| # text generation result for single sample | |||
| # { | |||
| @@ -162,7 +162,7 @@ TASK_INPUTS = { | |||
| 'source_sentence': InputType.LIST, | |||
| 'sentences_to_compare': InputType.LIST, | |||
| }, | |||
| Tasks.passage_ranking: (InputType.TEXT, InputType.TEXT), | |||
| Tasks.text_ranking: (InputType.TEXT, InputType.TEXT), | |||
| Tasks.text_generation: | |||
| InputType.TEXT, | |||
| Tasks.fill_mask: | |||
| @@ -20,8 +20,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| Tasks.sentence_embedding: | |||
| (Pipelines.sentence_embedding, | |||
| 'damo/nlp_corom_sentence-embedding_english-base'), | |||
| Tasks.passage_ranking: (Pipelines.passage_ranking, | |||
| 'damo/nlp_corom_passage-ranking_english-base'), | |||
| Tasks.text_ranking: (Pipelines.text_ranking, | |||
| 'damo/nlp_corom_passage-ranking_english-base'), | |||
| Tasks.word_segmentation: | |||
| (Pipelines.word_segmentation, | |||
| 'damo/nlp_structbert_word-segmentation_chinese-base'), | |||
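With the default-model table updated, callers use the new task name end to end. A minimal usage sketch; the 'sentences_to_compare' key is assumed to mirror the sentence-embedding input schema and the unit test inputs, and the candidate sentences are illustrative:

```python
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Resolves to 'damo/nlp_corom_passage-ranking_english-base' via DEFAULT_MODEL_FOR_PIPELINE.
text_ranking = pipeline(task=Tasks.text_ranking)
result = text_ranking(input={
    'source_sentence': ["how long it take to get a master's degree"],
    'sentences_to_compare': [  # assumed input key, matching the sentence-embedding schema
        "Most students take 18 to 24 months to complete a master's degree.",
        'The average doctoral program takes 8.2 years.',
    ],
})
print(result[OutputKeys.SCORES])  # Tasks.text_ranking outputs are registered as [SCORES]
```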
| @@ -17,7 +17,7 @@ if TYPE_CHECKING: | |||
| from .fill_mask_ponet_pipeline import FillMaskPonetPipeline | |||
| from .information_extraction_pipeline import InformationExtractionPipeline | |||
| from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline | |||
| from .passage_ranking_pipeline import PassageRankingPipeline | |||
| from .text_ranking_pipeline import TextRankingPipeline | |||
| from .sentence_embedding_pipeline import SentenceEmbeddingPipeline | |||
| from .sequence_classification_pipeline import SequenceClassificationPipeline | |||
| from .summarization_pipeline import SummarizationPipeline | |||
| @@ -51,7 +51,7 @@ else: | |||
| 'information_extraction_pipeline': ['InformationExtractionPipeline'], | |||
| 'named_entity_recognition_pipeline': | |||
| ['NamedEntityRecognitionPipeline'], | |||
| 'passage_ranking_pipeline': ['PassageRankingPipeline'], | |||
| 'text_ranking_pipeline': ['TextRankingPipeline'], | |||
| 'sentence_embedding_pipeline': ['SentenceEmbeddingPipeline'], | |||
| 'sequence_classification_pipeline': ['SequenceClassificationPipeline'], | |||
| 'summarization_pipeline': ['SummarizationPipeline'], | |||
| @@ -9,15 +9,15 @@ from modelscope.models import Model | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import PassageRankingPreprocessor, Preprocessor | |||
| from modelscope.preprocessors import Preprocessor, TextRankingPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| __all__ = ['PassageRankingPipeline'] | |||
| __all__ = ['TextRankingPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.passage_ranking, module_name=Pipelines.passage_ranking) | |||
| class PassageRankingPipeline(Pipeline): | |||
| Tasks.text_ranking, module_name=Pipelines.text_ranking) | |||
| class TextRankingPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[Model, str], | |||
| @@ -36,7 +36,7 @@ class PassageRankingPipeline(Pipeline): | |||
| Model) else Model.from_pretrained(model) | |||
| if preprocessor is None: | |||
| preprocessor = PassageRankingPreprocessor( | |||
| preprocessor = TextRankingPreprocessor( | |||
| model.model_dir if isinstance(model, Model) else model, | |||
| sequence_length=kwargs.pop('sequence_length', 128)) | |||
| model.eval() | |||
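Because the pipeline builds a default TextRankingPreprocessor when none is supplied, the tokenizer length can be overridden through the constructor. A sketch; the 256 value is illustrative:

```python
from modelscope.models import Model
from modelscope.pipelines.nlp import TextRankingPipeline

model = Model.from_pretrained('damo/nlp_corom_passage-ranking_english-base')
# With preprocessor left unset, __init__ creates a TextRankingPreprocessor
# using the sequence_length keyword (default 128).
ranking = TextRankingPipeline(model, sequence_length=256)
```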
| @@ -21,7 +21,7 @@ if TYPE_CHECKING: | |||
| FillMaskPoNetPreprocessor, | |||
| NLPPreprocessor, | |||
| NLPTokenizerPreprocessorBase, | |||
| PassageRankingPreprocessor, | |||
| TextRankingPreprocessor, | |||
| RelationExtractionPreprocessor, | |||
| SentenceEmbeddingPreprocessor, | |||
| SequenceClassificationPreprocessor, | |||
| @@ -62,7 +62,7 @@ else: | |||
| 'FillMaskPoNetPreprocessor', | |||
| 'NLPPreprocessor', | |||
| 'NLPTokenizerPreprocessorBase', | |||
| 'PassageRankingPreprocessor', | |||
| 'TextRankingPreprocessor', | |||
| 'RelationExtractionPreprocessor', | |||
| 'SentenceEmbeddingPreprocessor', | |||
| 'SequenceClassificationPreprocessor', | |||
| @@ -11,7 +11,7 @@ if TYPE_CHECKING: | |||
| FillMaskPoNetPreprocessor, | |||
| NLPPreprocessor, | |||
| NLPTokenizerPreprocessorBase, | |||
| PassageRankingPreprocessor, | |||
| TextRankingPreprocessor, | |||
| RelationExtractionPreprocessor, | |||
| SentenceEmbeddingPreprocessor, | |||
| SequenceClassificationPreprocessor, | |||
| @@ -33,7 +33,7 @@ else: | |||
| 'FillMaskPoNetPreprocessor', | |||
| 'NLPPreprocessor', | |||
| 'NLPTokenizerPreprocessorBase', | |||
| 'PassageRankingPreprocessor', | |||
| 'TextRankingPreprocessor', | |||
| 'RelationExtractionPreprocessor', | |||
| 'SentenceEmbeddingPreprocessor', | |||
| 'SequenceClassificationPreprocessor', | |||
| @@ -29,7 +29,7 @@ __all__ = [ | |||
| 'NLPPreprocessor', | |||
| 'FillMaskPoNetPreprocessor', | |||
| 'NLPTokenizerPreprocessorBase', | |||
| 'PassageRankingPreprocessor', | |||
| 'TextRankingPreprocessor', | |||
| 'RelationExtractionPreprocessor', | |||
| 'SentenceEmbeddingPreprocessor', | |||
| 'SequenceClassificationPreprocessor', | |||
| @@ -245,9 +245,9 @@ class NLPPreprocessor(NLPTokenizerPreprocessorBase): | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.passage_ranking) | |||
| class PassageRankingPreprocessor(NLPTokenizerPreprocessorBase): | |||
| """The tokenizer preprocessor used in passage ranking model. | |||
| Fields.nlp, module_name=Preprocessors.text_ranking) | |||
| class TextRankingPreprocessor(NLPTokenizerPreprocessorBase): | |||
| """The tokenizer preprocessor used in text-ranking model. | |||
| """ | |||
| def __init__(self, | |||
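The preprocessor keeps its registration under Fields.nlp but now under the 'text-ranking' key, so config entries such as {'type': 'text-ranking'} (used in the finetune test below) resolve to it. It can also be instantiated directly; a minimal sketch:

```python
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.preprocessors import TextRankingPreprocessor

model_dir = snapshot_download('damo/nlp_corom_passage-ranking_english-base')
# sequence_length is the same keyword the pipeline forwards when it builds this by default.
preprocessor = TextRankingPreprocessor(model_dir, sequence_length=128)
```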
| @@ -11,7 +11,7 @@ if TYPE_CHECKING: | |||
| ImagePortraitEnhancementTrainer, | |||
| MovieSceneSegmentationTrainer, ImageInpaintingTrainer) | |||
| from .multi_modal import CLIPTrainer | |||
| from .nlp import SequenceClassificationTrainer, PassageRankingTrainer | |||
| from .nlp import SequenceClassificationTrainer, TextRankingTrainer | |||
| from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer | |||
| from .trainer import EpochBasedTrainer | |||
| @@ -26,7 +26,7 @@ else: | |||
| 'ImageInpaintingTrainer' | |||
| ], | |||
| 'multi_modal': ['CLIPTrainer'], | |||
| 'nlp': ['SequenceClassificationTrainer', 'PassageRankingTrainer'], | |||
| 'nlp': ['SequenceClassificationTrainer', 'TextRankingTrainer'], | |||
| 'nlp_trainer': ['NlpEpochBasedTrainer', 'VecoTrainer'], | |||
| 'trainer': ['EpochBasedTrainer'] | |||
| } | |||
| @@ -6,12 +6,12 @@ from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .sequence_classification_trainer import SequenceClassificationTrainer | |||
| from .csanmt_translation_trainer import CsanmtTranslationTrainer | |||
| from .passage_ranking_trainer import PassageRankingTranier | |||
| from .text_ranking_trainer import TextRankingTrainer | |||
| else: | |||
| _import_structure = { | |||
| 'sequence_classification_trainer': ['SequenceClassificationTrainer'], | |||
| 'csanmt_translation_trainer': ['CsanmtTranslationTrainer'], | |||
| 'passage_ranking_trainer': ['PassageRankingTrainer'] | |||
| 'text_ranking_trainer': ['TextRankingTrainer'] | |||
| } | |||
| import sys | |||
| @@ -8,6 +8,7 @@ import numpy as np | |||
| import torch | |||
| from torch import nn | |||
| from torch.utils.data import DataLoader, Dataset | |||
| from tqdm import tqdm | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.models.base import Model, TorchModel | |||
| @@ -42,8 +43,8 @@ class GroupCollator(): | |||
| return batch | |||
| @TRAINERS.register_module(module_name=Trainers.nlp_passage_ranking_trainer) | |||
| class PassageRankingTrainer(NlpEpochBasedTrainer): | |||
| @TRAINERS.register_module(module_name=Trainers.nlp_text_ranking_trainer) | |||
| class TextRankingTrainer(NlpEpochBasedTrainer): | |||
| def __init__( | |||
| self, | |||
| @@ -117,7 +118,7 @@ class PassageRankingTrainer(NlpEpochBasedTrainer): | |||
| Example: | |||
| {"accuracy": 0.5091743119266054, "f1": 0.673780487804878} | |||
| """ | |||
| from modelscope.models.nlp import PassageRanking | |||
| from modelscope.models.nlp import TextRanking | |||
| # get the raw online dataset | |||
| self.eval_dataloader = self._build_dataloader_with_dataset( | |||
| self.eval_dataset, | |||
| @@ -126,7 +127,7 @@ class PassageRankingTrainer(NlpEpochBasedTrainer): | |||
| # generate a standard dataloader | |||
| # generate a model | |||
| if checkpoint_path is not None: | |||
| model = PassageRanking.from_pretrained(checkpoint_path) | |||
| model = TextRanking.from_pretrained(checkpoint_path) | |||
| else: | |||
| model = self.model | |||
| @@ -141,7 +142,7 @@ class PassageRankingTrainer(NlpEpochBasedTrainer): | |||
| total_spent_time = 0.0 | |||
| device = 'cuda:0' if torch.cuda.is_available() else 'cpu' | |||
| model.to(device) | |||
| for _step, batch in enumerate(self.eval_dataloader): | |||
| for _step, batch in enumerate(tqdm(self.eval_dataloader)): | |||
| try: | |||
| batch = { | |||
| key: | |||
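The renamed trainer remains registered through TRAINERS, so finetuning goes through build_trainer with the new key. A minimal sketch mirroring the finetune test below; the dataset id, work_dir, and trimmed-down cfg_modify_fn are illustrative placeholders rather than the exact values used in the test:

```python
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer

# Hypothetical ranking dataset; the finetune test prepares hf datasets the same way.
ds = MsDataset.load('msmarco_passage_ranking')  # illustrative dataset id
train_ds = ds['train'].to_hf_dataset()
dev_ds = ds['train'].to_hf_dataset()

def cfg_modify_fn(cfg):
    # Point the config at the renamed task and preprocessor type.
    cfg.task = 'text-ranking'
    cfg['preprocessor'] = {'type': 'text-ranking'}
    return cfg

trainer = build_trainer(
    name=Trainers.nlp_text_ranking_trainer,
    default_args=dict(
        model='damo/nlp_corom_passage-ranking_english-base',
        train_dataset=train_ds,
        eval_dataset=dev_ds,
        work_dir='/tmp/text_ranking_finetune',  # illustrative output directory
        cfg_modify_fn=cfg_modify_fn))
trainer.train()
```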
| @@ -103,7 +103,7 @@ class NLPTasks(object): | |||
| sentence_similarity = 'sentence-similarity' | |||
| text_classification = 'text-classification' | |||
| sentence_embedding = 'sentence-embedding' | |||
| passage_ranking = 'passage-ranking' | |||
| text_ranking = 'text-ranking' | |||
| relation_extraction = 'relation-extraction' | |||
| zero_shot = 'zero-shot' | |||
| translation = 'translation' | |||
| @@ -4,15 +4,15 @@ import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import PassageRanking | |||
| from modelscope.models.nlp import TextRanking | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.nlp import PassageRankingPipeline | |||
| from modelscope.preprocessors import PassageRankingPreprocessor | |||
| from modelscope.pipelines.nlp import TextRankingPipeline | |||
| from modelscope.preprocessors import TextRankingPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class PassageRankingTest(unittest.TestCase): | |||
| class TextRankingTest(unittest.TestCase): | |||
| model_id = 'damo/nlp_corom_passage-ranking_english-base' | |||
| inputs = { | |||
| 'source_sentence': ["how long it take to get a master's degree"], | |||
| @@ -27,11 +27,11 @@ class PassageRankingTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_by_direct_model_download(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| tokenizer = PassageRankingPreprocessor(cache_path) | |||
| model = PassageRanking.from_pretrained(cache_path) | |||
| pipeline1 = PassageRankingPipeline(model, preprocessor=tokenizer) | |||
| tokenizer = TextRankingPreprocessor(cache_path) | |||
| model = TextRanking.from_pretrained(cache_path) | |||
| pipeline1 = TextRankingPipeline(model, preprocessor=tokenizer) | |||
| pipeline2 = pipeline( | |||
| Tasks.passage_ranking, model=model, preprocessor=tokenizer) | |||
| Tasks.text_ranking, model=model, preprocessor=tokenizer) | |||
| print(f'sentence: {self.inputs}\n' | |||
| f'pipeline1:{pipeline1(input=self.inputs)}') | |||
| print() | |||
| @@ -40,20 +40,19 @@ class PassageRankingTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| tokenizer = PassageRankingPreprocessor(model.model_dir) | |||
| tokenizer = TextRankingPreprocessor(model.model_dir) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.passage_ranking, model=model, preprocessor=tokenizer) | |||
| task=Tasks.text_ranking, model=model, preprocessor=tokenizer) | |||
| print(pipeline_ins(input=self.inputs)) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.passage_ranking, model=self.model_id) | |||
| pipeline_ins = pipeline(task=Tasks.text_ranking, model=self.model_id) | |||
| print(pipeline_ins(input=self.inputs)) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipeline_ins = pipeline(task=Tasks.passage_ranking) | |||
| pipeline_ins = pipeline(task=Tasks.text_ranking) | |||
| print(pipeline_ins(input=self.inputs)) | |||
| @@ -41,7 +41,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| model_id, | |||
| train_dataset, | |||
| eval_dataset, | |||
| name=Trainers.nlp_passage_ranking_trainer, | |||
| name=Trainers.nlp_text_ranking_trainer, | |||
| cfg_modify_fn=None, | |||
| **kwargs): | |||
| kwargs = dict( | |||
| @@ -61,8 +61,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| def test_finetune_msmarco(self): | |||
| def cfg_modify_fn(cfg): | |||
| cfg.task = 'passage-ranking' | |||
| cfg['preprocessor'] = {'type': 'passage-ranking'} | |||
| cfg.task = 'text-ranking' | |||
| cfg['preprocessor'] = {'type': 'text-ranking'} | |||
| cfg.train.optimizer.lr = 2e-5 | |||
| cfg['dataset'] = { | |||
| 'train': { | |||
| @@ -105,7 +105,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| }, { | |||
| 'type': 'EvaluationHook', | |||
| 'by_epoch': False, | |||
| 'interval': 3000 | |||
| 'interval': 15 | |||
| }] | |||
| return cfg | |||
| @@ -114,18 +114,19 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| train_ds = ds['train'].to_hf_dataset() | |||
| dev_ds = ds['train'].to_hf_dataset() | |||
| model_id = 'damo/nlp_corom_passage-ranking_english-base' | |||
| self.finetune( | |||
| model_id='damo/nlp_corom_passage-ranking_english-base', | |||
| model_id=model_id, | |||
| train_dataset=train_ds, | |||
| eval_dataset=dev_ds, | |||
| cfg_modify_fn=cfg_modify_fn) | |||
| output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) | |||
| self.pipeline_passage_ranking(output_dir) | |||
| self.pipeline_text_ranking(output_dir) | |||
| def pipeline_passage_ranking(self, model_dir): | |||
| def pipeline_text_ranking(self, model_dir): | |||
| model = Model.from_pretrained(model_dir) | |||
| pipeline_ins = pipeline(task=Tasks.passage_ranking, model=model) | |||
| pipeline_ins = pipeline(task=Tasks.text_ranking, model=model) | |||
| print(pipeline_ins(input=self.inputs)) | |||