| @@ -236,7 +236,7 @@ class Pipelines(object): | |||
| conversational_text_to_sql = 'conversational-text-to-sql' | |||
| table_question_answering_pipeline = 'table-question-answering-pipeline' | |||
| sentence_embedding = 'sentence-embedding' | |||
| passage_ranking = 'passage-ranking' | |||
| text_ranking = 'text-ranking' | |||
| relation_extraction = 'relation-extraction' | |||
| document_segmentation = 'document-segmentation' | |||
| feature_extraction = 'feature-extraction' | |||
| @@ -298,7 +298,7 @@ class Trainers(object): | |||
| dialog_intent_trainer = 'dialog-intent-trainer' | |||
| nlp_base_trainer = 'nlp-base-trainer' | |||
| nlp_veco_trainer = 'nlp-veco-trainer' | |||
| nlp_passage_ranking_trainer = 'nlp-passage-ranking-trainer' | |||
| nlp_text_ranking_trainer = 'nlp-text-ranking-trainer' | |||
| # audio trainers | |||
| speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | |||
| @@ -344,7 +344,7 @@ class Preprocessors(object): | |||
| zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' | |||
| text_error_correction = 'text-error-correction' | |||
| sentence_embedding = 'sentence-embedding' | |||
| passage_ranking = 'passage-ranking' | |||
| text_ranking = 'text-ranking' | |||
| sequence_labeling_tokenizer = 'sequence-labeling-tokenizer' | |||
| word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor' | |||
| fill_mask = 'fill-mask' | |||
| @@ -56,9 +56,6 @@ class OneStageDetector(nn.Module): | |||
def inference(self, meta):
    """Run a gradient-free forward pass and post-process the predictions.

    Args:
        meta: Mapping that holds the network input under ``'img'``. The
            whole mapping is also forwarded to ``self.head.post_process``.

    Returns:
        Whatever ``self.head.post_process(preds, meta)`` returns.
    """
    # torch.cuda.synchronize() raises when no CUDA device is usable, so the
    # synchronisation points are only hit when CUDA is actually available.
    # This keeps CPU-only inference working and leaves GPU behaviour as-is.
    use_cuda = torch.cuda.is_available()
    with torch.no_grad():
        if use_cuda:
            torch.cuda.synchronize()
        preds = self(meta['img'])
        if use_cuda:
            torch.cuda.synchronize()
        results = self.head.post_process(preds, meta)
        if use_cuda:
            torch.cuda.synchronize()
    return results
| @@ -3,9 +3,9 @@ from modelscope.outputs import OutputKeys | |||
| from modelscope.utils.constant import Tasks | |||
| OFA_TASK_KEY_MAPPING = { | |||
| Tasks.ofa_ocr_recognition: OutputKeys.TEXT, | |||
| Tasks.ocr_recognition: OutputKeys.TEXT, | |||
| Tasks.image_captioning: OutputKeys.CAPTION, | |||
| Tasks.summarization: OutputKeys.TEXT, | |||
| Tasks.text_summarization: OutputKeys.TEXT, | |||
| Tasks.visual_question_answering: OutputKeys.TEXT, | |||
| Tasks.visual_grounding: OutputKeys.BOXES, | |||
| Tasks.text_classification: OutputKeys.LABELS, | |||
| @@ -28,13 +28,13 @@ __all__ = ['OfaForAllTasks'] | |||
| @MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa) | |||
| @MODELS.register_module(Tasks.ofa_ocr_recognition, module_name=Models.ofa) | |||
| @MODELS.register_module(Tasks.ocr_recognition, module_name=Models.ofa) | |||
| @MODELS.register_module(Tasks.visual_grounding, module_name=Models.ofa) | |||
| @MODELS.register_module( | |||
| Tasks.visual_question_answering, module_name=Models.ofa) | |||
| @MODELS.register_module(Tasks.visual_entailment, module_name=Models.ofa) | |||
| @MODELS.register_module(Tasks.image_classification, module_name=Models.ofa) | |||
| @MODELS.register_module(Tasks.summarization, module_name=Models.ofa) | |||
| @MODELS.register_module(Tasks.text_summarization, module_name=Models.ofa) | |||
| @MODELS.register_module(Tasks.text_classification, module_name=Models.ofa) | |||
| class OfaForAllTasks(TorchModel): | |||
| @@ -98,9 +98,9 @@ class OfaForAllTasks(TorchModel): | |||
| 'traverse': self._traverse_inference, | |||
| } | |||
| self.task_inference_mapping = { | |||
| Tasks.ofa_ocr_recognition: self._text_gen_inference, | |||
| Tasks.ocr_recognition: self._text_gen_inference, | |||
| Tasks.image_captioning: self._text_gen_inference, | |||
| Tasks.summarization: self._text_gen_inference, | |||
| Tasks.text_summarization: self._text_gen_inference, | |||
| Tasks.visual_grounding: self._visual_grounding_inference, | |||
| Tasks.visual_entailment: inference_d[self.gen_type], | |||
| Tasks.visual_question_answering: inference_d[self.gen_type], | |||
| @@ -34,8 +34,9 @@ if TYPE_CHECKING: | |||
| TaskModelForTextGeneration) | |||
| from .token_classification import SbertForTokenClassification | |||
| from .sentence_embedding import SentenceEmbedding | |||
| from .passage_ranking import PassageRanking | |||
| from .text_ranking import TextRanking | |||
| from .T5 import T5ForConditionalGeneration | |||
| else: | |||
| _import_structure = { | |||
| 'backbones': ['SbertModel'], | |||
| @@ -75,7 +76,7 @@ else: | |||
| 'token_classification': ['SbertForTokenClassification'], | |||
| 'table_question_answering': ['TableQuestionAnswering'], | |||
| 'sentence_embedding': ['SentenceEmbedding'], | |||
| 'passage_ranking': ['PassageRanking'], | |||
| 'text_ranking': ['TextRanking'], | |||
| 'T5': ['T5ForConditionalGeneration'], | |||
| } | |||
| @@ -10,6 +10,8 @@ from modelscope.utils.constant import Tasks | |||
| @HEADS.register_module( | |||
| Tasks.information_extraction, module_name=Heads.information_extraction) | |||
| @HEADS.register_module( | |||
| Tasks.relation_extraction, module_name=Heads.information_extraction) | |||
| class InformationExtractionHead(TorchHead): | |||
| def __init__(self, **kwargs): | |||
| @@ -16,6 +16,8 @@ __all__ = ['InformationExtractionModel'] | |||
| @MODELS.register_module( | |||
| Tasks.information_extraction, | |||
| module_name=TaskModels.information_extraction) | |||
| @MODELS.register_module( | |||
| Tasks.relation_extraction, module_name=TaskModels.information_extraction) | |||
| class InformationExtractionModel(SingleBackboneTaskModelBase): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| @@ -13,18 +13,18 @@ from modelscope.models.nlp.structbert import SbertPreTrainedModel | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.utils.constant import Tasks | |||
| __all__ = ['PassageRanking'] | |||
| __all__ = ['TextRanking'] | |||
| @MODELS.register_module(Tasks.passage_ranking, module_name=Models.bert) | |||
| class PassageRanking(SbertForSequenceClassification, SbertPreTrainedModel): | |||
| @MODELS.register_module(Tasks.text_ranking, module_name=Models.bert) | |||
| class TextRanking(SbertForSequenceClassification, SbertPreTrainedModel): | |||
| base_model_prefix: str = 'bert' | |||
| supports_gradient_checkpointing = True | |||
| _keys_to_ignore_on_load_missing = [r'position_ids'] | |||
| def __init__(self, config, model_dir, *args, **kwargs): | |||
| if hasattr(config, 'base_model_prefix'): | |||
| PassageRanking.base_model_prefix = config.base_model_prefix | |||
| TextRanking.base_model_prefix = config.base_model_prefix | |||
| super().__init__(config, model_dir) | |||
| self.train_batch_size = kwargs.get('train_batch_size', 4) | |||
| self.register_buffer( | |||
| @@ -74,7 +74,7 @@ class PassageRanking(SbertForSequenceClassification, SbertPreTrainedModel): | |||
| num_labels = kwargs.get('num_labels', 1) | |||
| model_args = {} if num_labels is None else {'num_labels': num_labels} | |||
| return super(SbertPreTrainedModel, PassageRanking).from_pretrained( | |||
| return super(SbertPreTrainedModel, TextRanking).from_pretrained( | |||
| pretrained_model_name_or_path=kwargs.get('model_dir'), | |||
| model_dir=kwargs.get('model_dir'), | |||
| **model_args) | |||
| @@ -12,14 +12,14 @@ if TYPE_CHECKING: | |||
| from .movie_scene_segmentation import MovieSceneSegmentationDataset | |||
| from .video_summarization_dataset import VideoSummarizationDataset | |||
| from .image_inpainting import ImageInpaintingDataset | |||
| from .passage_ranking_dataset import PassageRankingDataset | |||
| from .text_ranking_dataset import TextRankingDataset | |||
| else: | |||
| _import_structure = { | |||
| 'base': ['TaskDataset'], | |||
| 'builder': ['TASK_DATASETS', 'build_task_dataset'], | |||
| 'torch_base_dataset': ['TorchTaskDataset'], | |||
| 'passage_ranking_dataset': ['PassageRankingDataset'], | |||
| 'text_ranking_dataset': ['TextRankingDataset'], | |||
| 'veco_dataset': ['VecoDataset'], | |||
| 'image_instance_segmentation_coco_dataset': | |||
| ['ImageInstanceSegmentationCocoDataset'], | |||
| @@ -16,8 +16,8 @@ from .torch_base_dataset import TorchTaskDataset | |||
| @TASK_DATASETS.register_module( | |||
| group_key=Tasks.passage_ranking, module_name=Models.bert) | |||
| class PassageRankingDataset(TorchTaskDataset): | |||
| group_key=Tasks.text_ranking, module_name=Models.bert) | |||
| class TextRankingDataset(TorchTaskDataset): | |||
| def __init__(self, | |||
| datasets: Union[Any, List[Any]], | |||
| @@ -35,8 +35,8 @@ class PassageRankingDataset(TorchTaskDataset): | |||
| 'positive_passages') | |||
| self.neg_sequence = self.dataset_config.get('neg_sequence', | |||
| 'negative_passages') | |||
| self.passage_text_fileds = self.dataset_config.get( | |||
| 'passage_text_fileds', ['title', 'text']) | |||
| self.text_fileds = self.dataset_config.get('text_fileds', | |||
| ['title', 'text']) | |||
| self.qid_field = self.dataset_config.get('qid_field', 'query_id') | |||
| if mode == ModeKeys.TRAIN: | |||
| train_config = kwargs.get('train', {}) | |||
| @@ -58,14 +58,14 @@ class PassageRankingDataset(TorchTaskDataset): | |||
| pos_sequences = group[self.pos_sequence] | |||
| pos_sequences = [ | |||
| ' '.join([ele[key] for key in self.passage_text_fileds]) | |||
| ' '.join([ele[key] for key in self.text_fileds]) | |||
| for ele in pos_sequences | |||
| ] | |||
| labels.extend([1] * len(pos_sequences)) | |||
| neg_sequences = group[self.neg_sequence] | |||
| neg_sequences = [ | |||
| ' '.join([ele[key] for key in self.passage_text_fileds]) | |||
| ' '.join([ele[key] for key in self.text_fileds]) | |||
| for ele in neg_sequences | |||
| ] | |||
| @@ -88,13 +88,13 @@ class PassageRankingDataset(TorchTaskDataset): | |||
| pos_sequences = group[self.pos_sequence] | |||
| pos_sequences = [ | |||
| ' '.join([ele[key] for key in self.passage_text_fileds]) | |||
| ' '.join([ele[key] for key in self.text_fileds]) | |||
| for ele in pos_sequences | |||
| ] | |||
| neg_sequences = group[self.neg_sequence] | |||
| neg_sequences = [ | |||
| ' '.join([ele[key] for key in self.passage_text_fileds]) | |||
| ' '.join([ele[key] for key in self.text_fileds]) | |||
| for ele in neg_sequences | |||
| ] | |||
| @@ -506,7 +506,7 @@ TASK_OUTPUTS = { | |||
| # } | |||
| Tasks.text_error_correction: [OutputKeys.OUTPUT], | |||
| Tasks.sentence_embedding: [OutputKeys.TEXT_EMBEDDING, OutputKeys.SCORES], | |||
| Tasks.passage_ranking: [OutputKeys.SCORES], | |||
| Tasks.text_ranking: [OutputKeys.SCORES], | |||
| # text generation result for single sample | |||
| # { | |||
| @@ -661,7 +661,7 @@ TASK_OUTPUTS = { | |||
| # "caption": "this is an image caption text." | |||
| # } | |||
| Tasks.image_captioning: [OutputKeys.CAPTION], | |||
| Tasks.ofa_ocr_recognition: [OutputKeys.TEXT], | |||
| Tasks.ocr_recognition: [OutputKeys.TEXT], | |||
| # visual grounding result for single sample | |||
| # { | |||
| @@ -162,7 +162,7 @@ TASK_INPUTS = { | |||
| 'source_sentence': InputType.LIST, | |||
| 'sentences_to_compare': InputType.LIST, | |||
| }, | |||
| Tasks.passage_ranking: (InputType.TEXT, InputType.TEXT), | |||
| Tasks.text_ranking: (InputType.TEXT, InputType.TEXT), | |||
| Tasks.text_generation: | |||
| InputType.TEXT, | |||
| Tasks.fill_mask: | |||
| @@ -20,8 +20,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| Tasks.sentence_embedding: | |||
| (Pipelines.sentence_embedding, | |||
| 'damo/nlp_corom_sentence-embedding_english-base'), | |||
| Tasks.passage_ranking: (Pipelines.passage_ranking, | |||
| 'damo/nlp_corom_passage-ranking_english-base'), | |||
| Tasks.text_ranking: (Pipelines.text_ranking, | |||
| 'damo/nlp_corom_passage-ranking_english-base'), | |||
| Tasks.word_segmentation: | |||
| (Pipelines.word_segmentation, | |||
| 'damo/nlp_structbert_word-segmentation_chinese-base'), | |||
| @@ -31,6 +31,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| Tasks.named_entity_recognition: | |||
| (Pipelines.named_entity_recognition, | |||
| 'damo/nlp_raner_named-entity-recognition_chinese-base-news'), | |||
| Tasks.relation_extraction: | |||
| (Pipelines.relation_extraction, | |||
| 'damo/nlp_bert_relation-extraction_chinese-base'), | |||
| Tasks.information_extraction: | |||
| (Pipelines.relation_extraction, | |||
| 'damo/nlp_bert_relation-extraction_chinese-base'), | |||
| @@ -61,6 +61,8 @@ class FaceImageGenerationPipeline(Pipeline): | |||
| return input | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| if isinstance(input, str): | |||
| input = int(input) | |||
| assert isinstance(input, int) | |||
| torch.manual_seed(input) | |||
| torch.cuda.manual_seed(input) | |||
| @@ -11,6 +11,8 @@ from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.image_text_retrieval, module_name=Pipelines.multi_modal_embedding) | |||
| @PIPELINES.register_module( | |||
| Tasks.multi_modal_embedding, module_name=Pipelines.multi_modal_embedding) | |||
| class MultiModalEmbeddingPipeline(Pipeline): | |||
| @@ -16,7 +16,7 @@ logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.ofa_ocr_recognition, module_name=Pipelines.ofa_ocr_recognition) | |||
| Tasks.ocr_recognition, module_name=Pipelines.ofa_ocr_recognition) | |||
| class OcrRecognitionPipeline(Pipeline): | |||
| def __init__(self, | |||
| @@ -17,7 +17,7 @@ if TYPE_CHECKING: | |||
| from .fill_mask_ponet_pipeline import FillMaskPonetPipeline | |||
| from .information_extraction_pipeline import InformationExtractionPipeline | |||
| from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline | |||
| from .passage_ranking_pipeline import PassageRankingPipeline | |||
| from .text_ranking_pipeline import TextRankingPipeline | |||
| from .sentence_embedding_pipeline import SentenceEmbeddingPipeline | |||
| from .sequence_classification_pipeline import SequenceClassificationPipeline | |||
| from .summarization_pipeline import SummarizationPipeline | |||
| @@ -51,7 +51,7 @@ else: | |||
| 'information_extraction_pipeline': ['InformationExtractionPipeline'], | |||
| 'named_entity_recognition_pipeline': | |||
| ['NamedEntityRecognitionPipeline'], | |||
| 'passage_ranking_pipeline': ['PassageRankingPipeline'], | |||
| 'text_ranking_pipeline': ['TextRankingPipeline'], | |||
| 'sentence_embedding_pipeline': ['SentenceEmbeddingPipeline'], | |||
| 'sequence_classification_pipeline': ['SequenceClassificationPipeline'], | |||
| 'summarization_pipeline': ['SummarizationPipeline'], | |||
| @@ -17,6 +17,8 @@ __all__ = ['InformationExtractionPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.information_extraction, module_name=Pipelines.relation_extraction) | |||
| @PIPELINES.register_module( | |||
| Tasks.relation_extraction, module_name=Pipelines.relation_extraction) | |||
| class InformationExtractionPipeline(Pipeline): | |||
| def __init__(self, | |||
| @@ -13,7 +13,7 @@ logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.summarization, module_name=Pipelines.text_generation) | |||
| Tasks.text_summarization, module_name=Pipelines.text_generation) | |||
| class SummarizationPipeline(Pipeline): | |||
| def __init__(self, | |||
| @@ -72,6 +72,7 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||
| action = self.action_ops[result['action']] | |||
| headers = table['header_name'] | |||
| current_sql = result['sql'] | |||
| current_sql['from'] = [table['table_id']] | |||
| if history_sql is None: | |||
| return current_sql | |||
| @@ -216,10 +217,11 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||
| else: | |||
| return current_sql | |||
| def sql_dict_to_str(self, result, table): | |||
| def sql_dict_to_str(self, result, tables): | |||
| """ | |||
| convert sql struct to string | |||
| """ | |||
| table = tables[result['sql']['from'][0]] | |||
| header_names = table['header_name'] + ['空列'] | |||
| header_ids = table['header_id'] + ['null'] | |||
| sql = result['sql'] | |||
| @@ -279,42 +281,43 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||
| """ | |||
| result = inputs['result'] | |||
| history_sql = inputs['history_sql'] | |||
| result['sql'] = self.post_process_multi_turn( | |||
| history_sql=history_sql, | |||
| result=result, | |||
| table=self.db.tables[result['table_id']]) | |||
| result['sql']['from'] = [result['table_id']] | |||
| sql = self.sql_dict_to_str( | |||
| result=result, table=self.db.tables[result['table_id']]) | |||
| try: | |||
| result['sql'] = self.post_process_multi_turn( | |||
| history_sql=history_sql, | |||
| result=result, | |||
| table=self.db.tables[result['table_id']]) | |||
| except Exception: | |||
| result['sql'] = history_sql | |||
| sql = self.sql_dict_to_str(result=result, tables=self.db.tables) | |||
| # add sqlite | |||
| if self.db.is_use_sqlite: | |||
| try: | |||
| cursor = self.db.connection_obj.cursor().execute(sql.query) | |||
| names = [{ | |||
| 'name': | |||
| description[0], | |||
| 'label': | |||
| self.db.tables[result['table_id']]['headerid2name'].get( | |||
| description[0], description[0]) | |||
| } for description in cursor.description] | |||
| cells = [] | |||
| header_ids, header_names = [], [] | |||
| for description in cursor.description: | |||
| header_ids.append(self.db.tables[result['table_id']] | |||
| ['headerid2name'].get( | |||
| description[0], description[0])) | |||
| header_names.append(description[0]) | |||
| rows = [] | |||
| for res in cursor.fetchall(): | |||
| row = {} | |||
| for name, cell in zip(names, res): | |||
| row[name['name']] = cell | |||
| cells.append(row) | |||
| tabledata = {'headers': names, 'cells': cells} | |||
| rows.append(list(res)) | |||
| tabledata = { | |||
| 'header_id': header_ids, | |||
| 'header_name': header_names, | |||
| 'rows': rows | |||
| } | |||
| except Exception: | |||
| tabledata = {'headers': [], 'cells': []} | |||
| tabledata = {'header_id': [], 'header_name': [], 'rows': []} | |||
| else: | |||
| tabledata = {'headers': [], 'cells': []} | |||
| tabledata = {'header_id': [], 'header_name': [], 'rows': []} | |||
| output = { | |||
| OutputKeys.SQL_STRING: sql.string, | |||
| OutputKeys.SQL_QUERY: sql.query, | |||
| OutputKeys.HISTORY: result['sql'], | |||
| OutputKeys.QUERT_RESULT: json.dumps(tabledata, ensure_ascii=False), | |||
| OutputKeys.QUERT_RESULT: tabledata, | |||
| } | |||
| return output | |||
| @@ -9,15 +9,15 @@ from modelscope.models import Model | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import PassageRankingPreprocessor, Preprocessor | |||
| from modelscope.preprocessors import Preprocessor, TextRankingPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| __all__ = ['PassageRankingPipeline'] | |||
| __all__ = ['TextRankingPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.passage_ranking, module_name=Pipelines.passage_ranking) | |||
| class PassageRankingPipeline(Pipeline): | |||
| Tasks.text_ranking, module_name=Pipelines.text_ranking) | |||
| class TextRankingPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[Model, str], | |||
| @@ -36,7 +36,7 @@ class PassageRankingPipeline(Pipeline): | |||
| Model) else Model.from_pretrained(model) | |||
| if preprocessor is None: | |||
| preprocessor = PassageRankingPreprocessor( | |||
| preprocessor = TextRankingPreprocessor( | |||
| model.model_dir if isinstance(model, Model) else model, | |||
| sequence_length=kwargs.pop('sequence_length', 128)) | |||
| model.eval() | |||
| @@ -21,7 +21,7 @@ if TYPE_CHECKING: | |||
| FillMaskPoNetPreprocessor, | |||
| NLPPreprocessor, | |||
| NLPTokenizerPreprocessorBase, | |||
| PassageRankingPreprocessor, | |||
| TextRankingPreprocessor, | |||
| RelationExtractionPreprocessor, | |||
| SentenceEmbeddingPreprocessor, | |||
| SequenceClassificationPreprocessor, | |||
| @@ -62,7 +62,7 @@ else: | |||
| 'FillMaskPoNetPreprocessor', | |||
| 'NLPPreprocessor', | |||
| 'NLPTokenizerPreprocessorBase', | |||
| 'PassageRankingPreprocessor', | |||
| 'TextRankingPreprocessor', | |||
| 'RelationExtractionPreprocessor', | |||
| 'SentenceEmbeddingPreprocessor', | |||
| 'SequenceClassificationPreprocessor', | |||
| @@ -40,7 +40,7 @@ class OfaPreprocessor(Preprocessor): | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| preprocess_mapping = { | |||
| Tasks.ofa_ocr_recognition: OfaOcrRecognitionPreprocessor, | |||
| Tasks.ocr_recognition: OfaOcrRecognitionPreprocessor, | |||
| Tasks.image_captioning: OfaImageCaptioningPreprocessor, | |||
| Tasks.visual_grounding: OfaVisualGroundingPreprocessor, | |||
| Tasks.visual_question_answering: | |||
| @@ -48,14 +48,14 @@ class OfaPreprocessor(Preprocessor): | |||
| Tasks.visual_entailment: OfaVisualEntailmentPreprocessor, | |||
| Tasks.image_classification: OfaImageClassificationPreprocessor, | |||
| Tasks.text_classification: OfaTextClassificationPreprocessor, | |||
| Tasks.summarization: OfaSummarizationPreprocessor, | |||
| Tasks.text_summarization: OfaSummarizationPreprocessor, | |||
| Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor | |||
| } | |||
| input_key_mapping = { | |||
| Tasks.ofa_ocr_recognition: ['image'], | |||
| Tasks.ocr_recognition: ['image'], | |||
| Tasks.image_captioning: ['image'], | |||
| Tasks.image_classification: ['image'], | |||
| Tasks.summarization: ['text'], | |||
| Tasks.text_summarization: ['text'], | |||
| Tasks.text_classification: ['text', 'text2'], | |||
| Tasks.visual_grounding: ['image', 'text'], | |||
| Tasks.visual_question_answering: ['image', 'text'], | |||
| @@ -11,7 +11,7 @@ if TYPE_CHECKING: | |||
| FillMaskPoNetPreprocessor, | |||
| NLPPreprocessor, | |||
| NLPTokenizerPreprocessorBase, | |||
| PassageRankingPreprocessor, | |||
| TextRankingPreprocessor, | |||
| RelationExtractionPreprocessor, | |||
| SentenceEmbeddingPreprocessor, | |||
| SequenceClassificationPreprocessor, | |||
| @@ -33,7 +33,7 @@ else: | |||
| 'FillMaskPoNetPreprocessor', | |||
| 'NLPPreprocessor', | |||
| 'NLPTokenizerPreprocessorBase', | |||
| 'PassageRankingPreprocessor', | |||
| 'TextRankingPreprocessor', | |||
| 'RelationExtractionPreprocessor', | |||
| 'SentenceEmbeddingPreprocessor', | |||
| 'SequenceClassificationPreprocessor', | |||
| @@ -1,9 +1,10 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import os.path as osp | |||
| import re | |||
| from typing import Any, Dict, Iterable, List, Optional, Tuple, Union | |||
| from typing import Any, Dict, Optional, Tuple, Union | |||
| import json | |||
| import numpy as np | |||
| import sentencepiece as spm | |||
| import torch | |||
| @@ -13,8 +14,7 @@ from modelscope.metainfo import Models, Preprocessors | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.preprocessors.base import Preprocessor | |||
| from modelscope.preprocessors.builder import PREPROCESSORS | |||
| from modelscope.utils.config import (Config, ConfigFields, | |||
| use_task_specific_params) | |||
| from modelscope.utils.config import Config, ConfigFields | |||
| from modelscope.utils.constant import Fields, InputFields, ModeKeys, ModelFile | |||
| from modelscope.utils.hub import get_model_type, parse_label_mapping | |||
| from modelscope.utils.logger import get_logger | |||
| @@ -29,7 +29,7 @@ __all__ = [ | |||
| 'NLPPreprocessor', | |||
| 'FillMaskPoNetPreprocessor', | |||
| 'NLPTokenizerPreprocessorBase', | |||
| 'PassageRankingPreprocessor', | |||
| 'TextRankingPreprocessor', | |||
| 'RelationExtractionPreprocessor', | |||
| 'SentenceEmbeddingPreprocessor', | |||
| 'SequenceClassificationPreprocessor', | |||
| @@ -83,6 +83,15 @@ class NLPTokenizerPreprocessorBase(Preprocessor): | |||
| self._mode = mode | |||
| self.label = kwargs.pop('label', OutputKeys.LABEL) | |||
| self.use_fast = kwargs.pop('use_fast', None) | |||
| if self.use_fast is None and os.path.isfile( | |||
| os.path.join(model_dir, 'tokenizer_config.json')): | |||
| with open(os.path.join(model_dir, 'tokenizer_config.json'), | |||
| 'r') as f: | |||
| json_config = json.load(f) | |||
| self.use_fast = json_config.get('use_fast') | |||
| self.use_fast = False if self.use_fast is None else self.use_fast | |||
| self.label2id = None | |||
| if 'label2id' in kwargs: | |||
| self.label2id = kwargs.pop('label2id') | |||
| @@ -118,32 +127,23 @@ class NLPTokenizerPreprocessorBase(Preprocessor): | |||
| if model_type in (Models.structbert, Models.gpt3, Models.palm, | |||
| Models.plug): | |||
| from modelscope.models.nlp.structbert import SbertTokenizer, SbertTokenizerFast | |||
| return SbertTokenizer.from_pretrained( | |||
| model_dir | |||
| ) if self._mode == ModeKeys.INFERENCE else SbertTokenizerFast.from_pretrained( | |||
| model_dir) | |||
| tokenizer = SbertTokenizerFast if self.use_fast else SbertTokenizer | |||
| return tokenizer.from_pretrained(model_dir) | |||
| elif model_type == Models.veco: | |||
| from modelscope.models.nlp.veco import VecoTokenizer, VecoTokenizerFast | |||
| return VecoTokenizer.from_pretrained( | |||
| model_dir | |||
| ) if self._mode == ModeKeys.INFERENCE else VecoTokenizerFast.from_pretrained( | |||
| model_dir) | |||
| tokenizer = VecoTokenizerFast if self.use_fast else VecoTokenizer | |||
| return tokenizer.from_pretrained(model_dir) | |||
| elif model_type == Models.deberta_v2: | |||
| from modelscope.models.nlp.deberta_v2 import DebertaV2Tokenizer, DebertaV2TokenizerFast | |||
| return DebertaV2Tokenizer.from_pretrained( | |||
| model_dir | |||
| ) if self._mode == ModeKeys.INFERENCE else DebertaV2TokenizerFast.from_pretrained( | |||
| model_dir) | |||
| tokenizer = DebertaV2TokenizerFast if self.use_fast else DebertaV2Tokenizer | |||
| return tokenizer.from_pretrained(model_dir) | |||
| elif not self.is_transformer_based_model: | |||
| from transformers import BertTokenizer, BertTokenizerFast | |||
| return BertTokenizer.from_pretrained( | |||
| model_dir | |||
| ) if self._mode == ModeKeys.INFERENCE else BertTokenizerFast.from_pretrained( | |||
| model_dir) | |||
| tokenizer = BertTokenizerFast if self.use_fast else BertTokenizer | |||
| return tokenizer.from_pretrained(model_dir) | |||
| else: | |||
| return AutoTokenizer.from_pretrained( | |||
| model_dir, | |||
| use_fast=False if self._mode == ModeKeys.INFERENCE else True) | |||
| model_dir, use_fast=self.use_fast) | |||
| def __call__(self, data: Union[str, Tuple, Dict]) -> Dict[str, Any]: | |||
| """process the raw input data | |||
| @@ -245,9 +245,9 @@ class NLPPreprocessor(NLPTokenizerPreprocessorBase): | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.passage_ranking) | |||
| class PassageRankingPreprocessor(NLPTokenizerPreprocessorBase): | |||
| """The tokenizer preprocessor used in passage ranking model. | |||
| Fields.nlp, module_name=Preprocessors.text_ranking) | |||
| class TextRankingPreprocessor(NLPTokenizerPreprocessorBase): | |||
| """The tokenizer preprocessor used in text-ranking model. | |||
| """ | |||
| def __init__(self, | |||
| @@ -593,9 +593,6 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| else: | |||
| self.is_split_into_words = self.tokenizer.init_kwargs.get( | |||
| 'is_split_into_words', False) | |||
| if 'label2id' in kwargs: | |||
| kwargs.pop('label2id') | |||
| self.tokenize_kwargs = kwargs | |||
| @type_assert(object, str) | |||
| def __call__(self, data: str) -> Dict[str, Any]: | |||
| @@ -13,7 +13,7 @@ class Database: | |||
| tokenizer, | |||
| table_file_path, | |||
| syn_dict_file_path, | |||
| is_use_sqlite=False): | |||
| is_use_sqlite=True): | |||
| self.tokenizer = tokenizer | |||
| self.is_use_sqlite = is_use_sqlite | |||
| if self.is_use_sqlite: | |||
| @@ -293,6 +293,7 @@ class SchemaLinker: | |||
| nlu_t, | |||
| tables, | |||
| col_syn_dict, | |||
| table_id=None, | |||
| history_sql=None): | |||
| """ | |||
| get linking between question and schema column | |||
| @@ -300,6 +301,9 @@ class SchemaLinker: | |||
| typeinfos = [] | |||
| numbers = re.findall(r'[-]?\d*\.\d+|[-]?\d+|\d+', nlu) | |||
| if table_id is not None and table_id in tables: | |||
| tables = {table_id: tables[table_id]} | |||
| # search schema link in every table | |||
| search_result_list = [] | |||
| for tablename in tables: | |||
| @@ -411,26 +415,25 @@ class SchemaLinker: | |||
| # get the match score of each table | |||
| match_score = self.get_table_match_score(nlu_t, schema_link) | |||
| # cal table_score | |||
| if history_sql is not None and 'from' in history_sql: | |||
| table_score = int(table['table_id'] == history_sql['from'][0]) | |||
| else: | |||
| table_score = 0 | |||
| search_result = { | |||
| 'table_id': | |||
| table['table_id'], | |||
| 'question_knowledge': | |||
| final_question, | |||
| 'header_knowledge': | |||
| final_header, | |||
| 'schema_link': | |||
| schema_link, | |||
| 'match_score': | |||
| match_score, | |||
| 'table_score': | |||
| int(table['table_id'] == history_sql['from'][0]) | |||
| if history_sql is not None else 0 | |||
| 'table_id': table['table_id'], | |||
| 'question_knowledge': final_question, | |||
| 'header_knowledge': final_header, | |||
| 'schema_link': schema_link, | |||
| 'match_score': match_score, | |||
| 'table_score': table_score | |||
| } | |||
| search_result_list.append(search_result) | |||
| search_result_list = sorted( | |||
| search_result_list, | |||
| key=lambda x: (x['match_score'], x['table_score']), | |||
| reverse=True)[0:4] | |||
| reverse=True)[0:1] | |||
| return search_result_list | |||
| @@ -95,6 +95,7 @@ class TableQuestionAnsweringPreprocessor(Preprocessor): | |||
| # tokenize question | |||
| question = data['question'] | |||
| table_id = data.get('table_id', None) | |||
| history_sql = data.get('history_sql', None) | |||
| nlu = question.lower() | |||
| nlu_t = self.tokenizer.tokenize(nlu) | |||
| @@ -106,6 +107,7 @@ class TableQuestionAnsweringPreprocessor(Preprocessor): | |||
| nlu_t=nlu_t, | |||
| tables=self.db.tables, | |||
| col_syn_dict=self.db.syn_dict, | |||
| table_id=table_id, | |||
| history_sql=history_sql) | |||
| # collect data | |||
| @@ -11,7 +11,7 @@ if TYPE_CHECKING: | |||
| ImagePortraitEnhancementTrainer, | |||
| MovieSceneSegmentationTrainer, ImageInpaintingTrainer) | |||
| from .multi_modal import CLIPTrainer | |||
| from .nlp import SequenceClassificationTrainer, PassageRankingTrainer | |||
| from .nlp import SequenceClassificationTrainer, TextRankingTrainer | |||
| from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer | |||
| from .trainer import EpochBasedTrainer | |||
| @@ -26,7 +26,7 @@ else: | |||
| 'ImageInpaintingTrainer' | |||
| ], | |||
| 'multi_modal': ['CLIPTrainer'], | |||
| 'nlp': ['SequenceClassificationTrainer', 'PassageRankingTrainer'], | |||
| 'nlp': ['SequenceClassificationTrainer', 'TextRankingTrainer'], | |||
| 'nlp_trainer': ['NlpEpochBasedTrainer', 'VecoTrainer'], | |||
| 'trainer': ['EpochBasedTrainer'] | |||
| } | |||
| @@ -6,12 +6,12 @@ from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .sequence_classification_trainer import SequenceClassificationTrainer | |||
| from .csanmt_translation_trainer import CsanmtTranslationTrainer | |||
| from .passage_ranking_trainer import PassageRankingTranier | |||
| from .text_ranking_trainer import TextRankingTranier | |||
| else: | |||
| _import_structure = { | |||
| 'sequence_classification_trainer': ['SequenceClassificationTrainer'], | |||
| 'csanmt_translation_trainer': ['CsanmtTranslationTrainer'], | |||
| 'passage_ranking_trainer': ['PassageRankingTrainer'] | |||
| 'text_ranking_trainer': ['TextRankingTrainer'] | |||
| } | |||
| import sys | |||
| @@ -8,6 +8,7 @@ import numpy as np | |||
| import torch | |||
| from torch import nn | |||
| from torch.utils.data import DataLoader, Dataset | |||
| from tqdm import tqdm | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.models.base import Model, TorchModel | |||
| @@ -42,8 +43,8 @@ class GroupCollator(): | |||
| return batch | |||
| @TRAINERS.register_module(module_name=Trainers.nlp_passage_ranking_trainer) | |||
| class PassageRankingTrainer(NlpEpochBasedTrainer): | |||
| @TRAINERS.register_module(module_name=Trainers.nlp_text_ranking_trainer) | |||
| class TextRankingTrainer(NlpEpochBasedTrainer): | |||
| def __init__( | |||
| self, | |||
| @@ -117,7 +118,7 @@ class PassageRankingTrainer(NlpEpochBasedTrainer): | |||
| Example: | |||
| {"accuracy": 0.5091743119266054, "f1": 0.673780487804878} | |||
| """ | |||
| from modelscope.models.nlp import PassageRanking | |||
| from modelscope.models.nlp import TextRanking | |||
| # get the raw online dataset | |||
| self.eval_dataloader = self._build_dataloader_with_dataset( | |||
| self.eval_dataset, | |||
| @@ -126,7 +127,7 @@ class PassageRankingTrainer(NlpEpochBasedTrainer): | |||
| # generate a standard dataloader | |||
| # generate a model | |||
| if checkpoint_path is not None: | |||
| model = PassageRanking.from_pretrained(checkpoint_path) | |||
| model = TextRanking.from_pretrained(checkpoint_path) | |||
| else: | |||
| model = self.model | |||
| @@ -141,7 +142,7 @@ class PassageRankingTrainer(NlpEpochBasedTrainer): | |||
| total_spent_time = 0.0 | |||
| device = 'cuda:0' if torch.cuda.is_available() else 'cpu' | |||
| model.to(device) | |||
| for _step, batch in enumerate(self.eval_dataloader): | |||
| for _step, batch in enumerate(tqdm(self.eval_dataloader)): | |||
| try: | |||
| batch = { | |||
| key: | |||
| @@ -103,7 +103,7 @@ class NLPTasks(object): | |||
| sentence_similarity = 'sentence-similarity' | |||
| text_classification = 'text-classification' | |||
| sentence_embedding = 'sentence-embedding' | |||
| passage_ranking = 'passage-ranking' | |||
| text_ranking = 'text-ranking' | |||
| relation_extraction = 'relation-extraction' | |||
| zero_shot = 'zero-shot' | |||
| translation = 'translation' | |||
| @@ -117,7 +117,7 @@ class NLPTasks(object): | |||
| table_question_answering = 'table-question-answering' | |||
| sentence_embedding = 'sentence-embedding' | |||
| fill_mask = 'fill-mask' | |||
| summarization = 'summarization' | |||
| text_summarization = 'text-summarization' | |||
| question_answering = 'question-answering' | |||
| zero_shot_classification = 'zero-shot-classification' | |||
| backbone = 'backbone' | |||
| @@ -151,7 +151,6 @@ class MultiModalTasks(object): | |||
| visual_entailment = 'visual-entailment' | |||
| video_multi_modal_embedding = 'video-multi-modal-embedding' | |||
| image_text_retrieval = 'image-text-retrieval' | |||
| ofa_ocr_recognition = 'ofa-ocr-recognition' | |||
| class TasksIODescriptions(object): | |||
| @@ -196,8 +196,7 @@ def build_from_cfg(cfg, | |||
| raise KeyError( | |||
| f'{obj_type} is not in the {registry.name}' | |||
| f' registry group {group_key}. Please make' | |||
| f' sure the correct version of 1qqQModelScope library is used.' | |||
| ) | |||
| f' sure the correct version of ModelScope library is used.') | |||
| obj_cls.group_key = group_key | |||
| elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): | |||
| obj_cls = obj_type | |||
| @@ -22,7 +22,7 @@ class TestExportSbertSequenceClassification(unittest.TestCase): | |||
| shutil.rmtree(self.tmp_dir) | |||
| super().tearDown() | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| @unittest.skip | |||
| def test_export_sbert_sequence_classification(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| print( | |||
| @@ -71,7 +71,7 @@ class MsDatasetTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| @require_torch | |||
| def test_to_torch_dataset_text(self): | |||
| model_id = 'damo/bert-base-sst2' | |||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' | |||
| nlp_model = Model.from_pretrained(model_id) | |||
| preprocessor = SequenceClassificationPreprocessor( | |||
| nlp_model.model_dir, | |||
| @@ -93,7 +93,7 @@ class MsDatasetTest(unittest.TestCase): | |||
| def test_to_tf_dataset_text(self): | |||
| import tensorflow as tf | |||
| tf.compat.v1.enable_eager_execution() | |||
| model_id = 'damo/bert-base-sst2' | |||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' | |||
| nlp_model = Model.from_pretrained(model_id) | |||
| preprocessor = SequenceClassificationPreprocessor( | |||
| nlp_model.model_dir, | |||
| @@ -80,164 +80,141 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, | |||
| all_models_info = [ | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1', | |||
| 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1', | |||
| 'wav_path': 'data/test/audios/asr_example.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': 'speech_paraformer_asr_nat-aishell1-pytorch', | |||
| 'model_id': 'damo/speech_paraformer_asr_nat-aishell1-pytorch', | |||
| 'wav_path': 'data/test/audios/asr_example.wav' | |||
| }, | |||
| { | |||
| 'model_id': 'damo/speech_paraformer_asr_nat-aishell2-pytorch', | |||
| 'wav_path': 'data/test/audios/asr_example.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1', | |||
| 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1', | |||
| 'wav_path': 'data/test/audios/asr_example.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1', | |||
| 'damo/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1', | |||
| 'wav_path': 'data/test/audios/asr_example_8K.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online', | |||
| 'damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online', | |||
| 'wav_path': 'data/test/audios/asr_example.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline', | |||
| 'damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline', | |||
| 'wav_path': 'data/test/audios/asr_example.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online', | |||
| 'damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online', | |||
| 'wav_path': 'data/test/audios/asr_example_8K.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline', | |||
| 'damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline', | |||
| 'wav_path': 'data/test/audios/asr_example_8K.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline', | |||
| 'damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline', | |||
| 'wav_path': 'data/test/audios/asr_example.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online', | |||
| 'damo/speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online', | |||
| 'wav_path': 'data/test/audios/asr_example_cn_en.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline', | |||
| 'damo/speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline', | |||
| 'wav_path': 'data/test/audios/asr_example_cn_en.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online', | |||
| 'damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online', | |||
| 'wav_path': 'data/test/audios/asr_example_cn_dialect.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline', | |||
| 'damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline', | |||
| 'wav_path': 'data/test/audios/asr_example_cn_dialect.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online', | |||
| 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online', | |||
| 'wav_path': 'data/test/audios/asr_example.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online', | |||
| 'damo/speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online', | |||
| 'wav_path': 'data/test/audios/asr_example_8K.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline', | |||
| 'damo/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline', | |||
| 'wav_path': 'data/test/audios/asr_example_en.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online', | |||
| 'damo/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online', | |||
| 'wav_path': 'data/test/audios/asr_example_en.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline', | |||
| 'damo/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline', | |||
| 'wav_path': 'data/test/audios/asr_example_ru.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online', | |||
| 'damo/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online', | |||
| 'wav_path': 'data/test/audios/asr_example_ru.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline', | |||
| 'damo/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline', | |||
| 'wav_path': 'data/test/audios/asr_example_es.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online', | |||
| 'damo/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online', | |||
| 'wav_path': 'data/test/audios/asr_example_es.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline', | |||
| 'damo/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline', | |||
| 'wav_path': 'data/test/audios/asr_example_ko.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online', | |||
| 'damo/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online', | |||
| 'wav_path': 'data/test/audios/asr_example_ko.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online', | |||
| 'damo/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online', | |||
| 'wav_path': 'data/test/audios/asr_example_ja.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline', | |||
| 'damo/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline', | |||
| 'wav_path': 'data/test/audios/asr_example_ja.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online', | |||
| 'damo/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online', | |||
| 'wav_path': 'data/test/audios/asr_example_id.wav' | |||
| }, | |||
| { | |||
| 'model_group': 'damo', | |||
| 'model_id': | |||
| 'speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline', | |||
| 'damo/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline', | |||
| 'wav_path': 'data/test/audios/asr_example_id.wav' | |||
| }, | |||
| ] | |||
| @@ -404,7 +381,7 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase, | |||
| logger.info('Run ASR test with all models') | |||
| for item in self.all_models_info: | |||
| model_id = item['model_group'] + '/' + item['model_id'] | |||
| model_id = item['model_id'] | |||
| wav_path = item['wav_path'] | |||
| rec_result = self.run_pipeline( | |||
| model_id=model_id, audio_in=wav_path) | |||
| @@ -17,12 +17,12 @@ class TextGPT3GenerationTest(unittest.TestCase): | |||
| self.model_dir_13B = snapshot_download(self.model_id_13B) | |||
| self.input = '好的' | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| @unittest.skip('distributed gpt3 1.3B, skipped') | |||
| def test_gpt3_1_3B(self): | |||
| pipe = pipeline(Tasks.text_generation, model=self.model_id_1_3B) | |||
| print(pipe(self.input)) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| @unittest.skip('distributed gpt3 2.7B, skipped') | |||
| def test_gpt3_2_7B(self): | |||
| pipe = pipeline(Tasks.text_generation, model=self.model_id_2_7B) | |||
| print(pipe(self.input)) | |||
| @@ -48,7 +48,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_ocr_recognize_with_name(self): | |||
| ocr_recognize = pipeline( | |||
| Tasks.ofa_ocr_recognition, | |||
| Tasks.ocr_recognition, | |||
| model='damo/ofa_ocr-recognition_scene_base_zh') | |||
| result = ocr_recognize('data/test/images/image_ocr_recognition.jpg') | |||
| print(result[OutputKeys.TEXT]) | |||
| @@ -75,7 +75,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| def test_run_with_summarization_with_model(self): | |||
| model = Model.from_pretrained( | |||
| 'damo/ofa_summarization_gigaword_large_en') | |||
| ofa_pipe = pipeline(Tasks.summarization, model=model) | |||
| ofa_pipe = pipeline(Tasks.text_summarization, model=model) | |||
| text = 'five-time world champion michelle kwan withdrew' + \ | |||
| 'from the #### us figure skating championships on wednesday ,' + \ | |||
| ' but will petition us skating officials for the chance to ' + \ | |||
| @@ -87,7 +87,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_summarization_with_name(self): | |||
| ofa_pipe = pipeline( | |||
| Tasks.summarization, | |||
| Tasks.text_summarization, | |||
| model='damo/ofa_summarization_gigaword_large_en') | |||
| text = 'five-time world champion michelle kwan withdrew' + \ | |||
| 'from the #### us figure skating championships on wednesday ,' + \ | |||
| @@ -15,7 +15,7 @@ from modelscope.utils.test_utils import test_level | |||
| class RelationExtractionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| def setUp(self) -> None: | |||
| self.task = Tasks.information_extraction | |||
| self.task = Tasks.relation_extraction | |||
| self.model_id = 'damo/nlp_bert_relation-extraction_chinese-base' | |||
| sentence = '高捷,祖籍江苏,本科毕业于东南大学' | |||
| @@ -28,7 +28,7 @@ class RelationExtractionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| pipeline1 = InformationExtractionPipeline( | |||
| model, preprocessor=tokenizer) | |||
| pipeline2 = pipeline( | |||
| Tasks.information_extraction, model=model, preprocessor=tokenizer) | |||
| Tasks.relation_extraction, model=model, preprocessor=tokenizer) | |||
| print(f'sentence: {self.sentence}\n' | |||
| f'pipeline1:{pipeline1(input=self.sentence)}') | |||
| print() | |||
| @@ -39,7 +39,7 @@ class RelationExtractionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| model = Model.from_pretrained(self.model_id) | |||
| tokenizer = RelationExtractionPreprocessor(model.model_dir) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.information_extraction, | |||
| task=Tasks.relation_extraction, | |||
| model=model, | |||
| preprocessor=tokenizer) | |||
| print(pipeline_ins(input=self.sentence)) | |||
| @@ -47,12 +47,12 @@ class RelationExtractionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.information_extraction, model=self.model_id) | |||
| task=Tasks.relation_extraction, model=self.model_id) | |||
| print(pipeline_ins(input=self.sentence)) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipeline_ins = pipeline(task=Tasks.information_extraction) | |||
| pipeline_ins = pipeline(task=Tasks.relation_extraction) | |||
| print(pipeline_ins(input=self.sentence)) | |||
| @unittest.skip('demo compatibility test is only enabled on a needed-basis') | |||
| @@ -43,7 +43,7 @@ def tableqa_tracking_and_print_results_with_history( | |||
| print('sql text:', output_dict[OutputKeys.SQL_STRING]) | |||
| print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | |||
| print('query result:', output_dict[OutputKeys.QUERT_RESULT]) | |||
| print('json dumps', json.dumps(output_dict)) | |||
| print('json dumps', json.dumps(output_dict, ensure_ascii=False)) | |||
| print() | |||
| historical_queries = output_dict[OutputKeys.HISTORY] | |||
| @@ -66,10 +66,42 @@ def tableqa_tracking_and_print_results_without_history( | |||
| print('sql text:', output_dict[OutputKeys.SQL_STRING]) | |||
| print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | |||
| print('query result:', output_dict[OutputKeys.QUERT_RESULT]) | |||
| print('json dumps', json.dumps(output_dict)) | |||
| print('json dumps', json.dumps(output_dict, ensure_ascii=False)) | |||
| print() | |||
def tableqa_tracking_and_print_results_with_tableid(
        pipelines: List[TableQuestionAnsweringPipeline]):
    """Run a scripted multi-turn conversation with explicit table ids.

    For every pipeline the conversation starts with no history. Each turn
    sends the question, its table id and the history returned by the
    previous turn, then prints the question, the SQL text, the SQL query,
    the query result and a JSON dump of the whole output.

    Args:
        pipelines: Pipelines to exercise; each one is called with a dict
            holding ``question``, ``table_id`` and ``history_sql``.
    """
    # (question, table_id) pairs, asked in exactly this order.
    turns = (
        ('有哪些风险类型？', 'fund'),
        ('风险类型有多少种？', 'reservoir'),
        ('珠江流域的小(2)型水库的库容总量是多少？', 'reservoir'),
        ('那平均值是多少？', 'reservoir'),
        ('那水库的名称呢？', 'reservoir'),
        ('换成中型的呢？', 'reservoir'),
        ('枣庄营业厅的电话', 'business'),
        ('那地址呢？', 'business'),
        ('枣庄营业厅的电话和地址', 'business'),
    )

    for tableqa in pipelines:
        # History is reset for each pipeline and threaded through its turns.
        history = None
        for question, table_id in turns:
            request = {
                'question': question,
                'table_id': table_id,
                'history_sql': history
            }
            output_dict = tableqa(request)

            print('question', question)
            print('sql text:', output_dict[OutputKeys.SQL_STRING])
            print('sql query:', output_dict[OutputKeys.SQL_QUERY])
            print('query result:', output_dict[OutputKeys.QUERT_RESULT])
            # ensure_ascii=False keeps the Chinese text readable in the dump.
            print('json dumps', json.dumps(output_dict, ensure_ascii=False))
            print()

            history = output_dict[OutputKeys.HISTORY]
| class TableQuestionAnswering(unittest.TestCase): | |||
| def setUp(self) -> None: | |||
| @@ -93,15 +125,27 @@ class TableQuestionAnswering(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| self.tokenizer = BertTokenizer( | |||
| os.path.join(model.model_dir, ModelFile.VOCAB_FILE)) | |||
| db = Database( | |||
| tokenizer=self.tokenizer, | |||
| table_file_path=[ | |||
| os.path.join(model.model_dir, 'databases', fname) | |||
| for fname in os.listdir( | |||
| os.path.join(model.model_dir, 'databases')) | |||
| ], | |||
| syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'), | |||
| is_use_sqlite=False) | |||
| preprocessor = TableQuestionAnsweringPreprocessor( | |||
| model_dir=model.model_dir) | |||
| model_dir=model.model_dir, db=db) | |||
| pipelines = [ | |||
| pipeline( | |||
| Tasks.table_question_answering, | |||
| model=model, | |||
| preprocessor=preprocessor) | |||
| preprocessor=preprocessor, | |||
| db=db) | |||
| ] | |||
| tableqa_tracking_and_print_results_with_history(pipelines) | |||
| tableqa_tracking_and_print_results_with_tableid(pipelines) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_task(self): | |||
| @@ -132,7 +176,6 @@ class TableQuestionAnswering(unittest.TestCase): | |||
| db=db) | |||
| ] | |||
| tableqa_tracking_and_print_results_without_history(pipelines) | |||
| tableqa_tracking_and_print_results_with_history(pipelines) | |||
| if __name__ == '__main__': | |||
| @@ -1,100 +0,0 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from modelscope.models import Model | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.nlp import SequenceClassificationPipeline | |||
| from modelscope.preprocessors import SequenceClassificationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
| from modelscope.utils.test_utils import test_level | |||
class SequenceClassificationTest(unittest.TestCase, DemoCompatibilityCheck):
    """Text-classification pipeline tests for ``damo/bert-base-sst2``.

    Every test method here is decorated with ``unittest.skip``; the
    commented-out ``skipUnless`` lines record the test level each one was
    previously gated on.
    """

    sentence1 = 'i like this wonderful place'

    def setUp(self) -> None:
        """Set the model id and task attributes used by the tests."""
        self.model_id = 'damo/bert-base-sst2'
        self.task = Tasks.text_classification

    def predict(self, pipeline_ins: SequenceClassificationPipeline):
        """Run the pipeline on the first two GLUE SST-2 test sentences.

        Prints both results, then the three sentences that were sliced.
        """
        from easynlp.appzoo import load_dataset

        # Named ``sst2`` (not ``set``) so the builtin ``set`` is not shadowed.
        sst2 = load_dataset('glue', 'sst2')
        data = sst2['test']['sentence'][:3]

        results = pipeline_ins(data[0])
        print(results)
        results = pipeline_ins(data[1])
        print(results)

        print(data)

    def printDataset(self, dataset: MsDataset):
        """Print at most the first 11 records (indices 0 to 10)."""
        for i, r in enumerate(dataset):
            if i > 10:
                break
            print(r)

    # @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    @unittest.skip('nlp model does not support tensor input, skipped')
    def test_run_with_model_from_modelhub(self):
        """Build the pipeline from a loaded model plus an explicit preprocessor."""
        model = Model.from_pretrained(self.model_id)
        preprocessor = SequenceClassificationPreprocessor(
            model.model_dir, first_sequence='sentence', second_sequence=None)
        pipeline_ins = pipeline(
            task=Tasks.text_classification,
            model=model,
            preprocessor=preprocessor)
        print(f'sentence1: {self.sentence1}\n'
              f'pipeline1:{pipeline_ins(input=self.sentence1)}')

    # @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    @unittest.skip('nlp model does not support tensor input, skipped')
    def test_run_with_model_name(self):
        """Build the pipeline from the model id and feed it an MsDataset."""
        text_classification = pipeline(
            task=Tasks.text_classification, model=self.model_id)
        result = text_classification(
            MsDataset.load(
                'xcopa',
                subset_name='translation-et',
                namespace='damotest',
                split='test',
                target='premise'))
        self.printDataset(result)

    # @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    @unittest.skip('nlp model does not support tensor input, skipped')
    def test_run_with_default_model(self):
        """Build the pipeline from the task alone and feed it an MsDataset."""
        text_classification = pipeline(task=Tasks.text_classification)
        result = text_classification(
            MsDataset.load(
                'xcopa',
                subset_name='translation-et',
                namespace='damotest',
                split='test',
                target='premise'))
        self.printDataset(result)

    # @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    @unittest.skip('nlp model does not support tensor input, skipped')
    def test_run_with_modelscope_dataset(self):
        """Load the dataset first, then pass it to a task-only pipeline."""
        text_classification = pipeline(task=Tasks.text_classification)
        # loaded from modelscope dataset
        dataset = MsDataset.load(
            'xcopa',
            subset_name='translation-et',
            namespace='damotest',
            split='test',
            target='premise')
        result = text_classification(dataset)
        self.printDataset(result)

    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
    def test_demo_compatibility(self):
        """Delegate to the inherited ``compatibility_check``."""
        self.compatibility_check()
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -4,15 +4,15 @@ import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import PassageRanking | |||
| from modelscope.models.nlp import TextRanking | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.nlp import PassageRankingPipeline | |||
| from modelscope.preprocessors import PassageRankingPreprocessor | |||
| from modelscope.pipelines.nlp import TextRankingPipeline | |||
| from modelscope.preprocessors import TextRankingPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class PassageRankingTest(unittest.TestCase): | |||
| class TextRankingTest(unittest.TestCase): | |||
| model_id = 'damo/nlp_corom_passage-ranking_english-base' | |||
| inputs = { | |||
| 'source_sentence': ["how long it take to get a master's degree"], | |||
| @@ -27,11 +27,11 @@ class PassageRankingTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_by_direct_model_download(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| tokenizer = PassageRankingPreprocessor(cache_path) | |||
| model = PassageRanking.from_pretrained(cache_path) | |||
| pipeline1 = PassageRankingPipeline(model, preprocessor=tokenizer) | |||
| tokenizer = TextRankingPreprocessor(cache_path) | |||
| model = TextRanking.from_pretrained(cache_path) | |||
| pipeline1 = TextRankingPipeline(model, preprocessor=tokenizer) | |||
| pipeline2 = pipeline( | |||
| Tasks.passage_ranking, model=model, preprocessor=tokenizer) | |||
| Tasks.text_ranking, model=model, preprocessor=tokenizer) | |||
| print(f'sentence: {self.inputs}\n' | |||
| f'pipeline1:{pipeline1(input=self.inputs)}') | |||
| print() | |||
| @@ -40,20 +40,19 @@ class PassageRankingTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| tokenizer = PassageRankingPreprocessor(model.model_dir) | |||
| tokenizer = TextRankingPreprocessor(model.model_dir) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.passage_ranking, model=model, preprocessor=tokenizer) | |||
| task=Tasks.text_ranking, model=model, preprocessor=tokenizer) | |||
| print(pipeline_ins(input=self.inputs)) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.passage_ranking, model=self.model_id) | |||
| pipeline_ins = pipeline(task=Tasks.text_ranking, model=self.model_id) | |||
| print(pipeline_ins(input=self.inputs)) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipeline_ins = pipeline(task=Tasks.passage_ranking) | |||
| pipeline_ins = pipeline(task=Tasks.text_ranking) | |||
| print(pipeline_ins(input=self.inputs)) | |||
| @@ -38,7 +38,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| shutil.rmtree(self.tmp_dir) | |||
| super().tearDown() | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| @unittest.skip( | |||
| 'Skip testing trainer repeatable, because it\'s unstable in daily UT') | |||
| def test_trainer_repeatable(self): | |||
| import torch # noqa | |||
| @@ -41,7 +41,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| model_id, | |||
| train_dataset, | |||
| eval_dataset, | |||
| name=Trainers.nlp_passage_ranking_trainer, | |||
| name=Trainers.nlp_text_ranking_trainer, | |||
| cfg_modify_fn=None, | |||
| **kwargs): | |||
| kwargs = dict( | |||
| @@ -61,8 +61,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| def test_finetune_msmarco(self): | |||
| def cfg_modify_fn(cfg): | |||
| cfg.task = 'passage-ranking' | |||
| cfg['preprocessor'] = {'type': 'passage-ranking'} | |||
| cfg.task = 'text-ranking' | |||
| cfg['preprocessor'] = {'type': 'text-ranking'} | |||
| cfg.train.optimizer.lr = 2e-5 | |||
| cfg['dataset'] = { | |||
| 'train': { | |||
| @@ -105,7 +105,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| }, { | |||
| 'type': 'EvaluationHook', | |||
| 'by_epoch': False, | |||
| 'interval': 3000 | |||
| 'interval': 15 | |||
| }] | |||
| return cfg | |||
| @@ -114,18 +114,19 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| train_ds = ds['train'].to_hf_dataset() | |||
| dev_ds = ds['train'].to_hf_dataset() | |||
| model_id = 'damo/nlp_corom_passage-ranking_english-base' | |||
| self.finetune( | |||
| model_id='damo/nlp_corom_passage-ranking_english-base', | |||
| model_id=model_id, | |||
| train_dataset=train_ds, | |||
| eval_dataset=dev_ds, | |||
| cfg_modify_fn=cfg_modify_fn) | |||
| output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) | |||
| self.pipeline_passage_ranking(output_dir) | |||
| self.pipeline_text_ranking(output_dir) | |||
| def pipeline_passage_ranking(self, model_dir): | |||
| def pipeline_text_ranking(self, model_dir): | |||
| model = Model.from_pretrained(model_dir) | |||
| pipeline_ins = pipeline(task=Tasks.passage_ranking, model=model) | |||
| pipeline_ins = pipeline(task=Tasks.text_ranking, model=model) | |||
| print(pipeline_ins(input=self.inputs)) | |||
| @@ -37,13 +37,13 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase): | |||
| namespace='modelscope', | |||
| subset_name='default', | |||
| split='test', | |||
| download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds | |||
| download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds | |||
| dataset_val = MsDataset.load( | |||
| 'image-portrait-enhancement-dataset', | |||
| namespace='modelscope', | |||
| subset_name='default', | |||
| split='test', | |||
| download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds | |||
| download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds | |||
| self.dataset_train = ImagePortraitEnhancementDataset( | |||
| dataset_train, is_train=True) | |||
| @@ -169,11 +169,25 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
| cfg.preprocessor.label = 'label' | |||
| cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} | |||
| cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} | |||
| cfg.train.dataloader.batch_size_per_gpu = 2 | |||
| cfg.train.hooks = [{ | |||
| 'type': 'CheckpointHook', | |||
| 'interval': 3, | |||
| 'by_epoch': False, | |||
| }, { | |||
| 'type': 'TextLoggerHook', | |||
| 'interval': 1 | |||
| }, { | |||
| 'type': 'IterTimerHook' | |||
| }, { | |||
| 'type': 'EvaluationHook', | |||
| 'interval': 1 | |||
| }] | |||
| cfg.train.work_dir = self.tmp_dir | |||
| cfg_file = os.path.join(self.tmp_dir, 'config.json') | |||
| cfg.dump(cfg_file) | |||
| dataset = MsDataset.load('clue', subset_name='afqmc', split='train') | |||
| dataset = dataset.to_hf_dataset().select(range(128)) | |||
| dataset = dataset.to_hf_dataset().select(range(4)) | |||
| kwargs = dict( | |||
| model=model_id, | |||
| train_dataset=dataset, | |||
| @@ -190,7 +204,7 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
| PRIORITY = Priority.VERY_LOW | |||
| def after_iter(self, trainer): | |||
| if trainer.iter == 12: | |||
| if trainer.iter == 3: | |||
| raise MsRegressTool.EarlyStopError('Test finished.') | |||
| if 'EarlyStopHook' not in [ | |||
| @@ -207,12 +221,11 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
| results_files = os.listdir(self.tmp_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| trainer = build_trainer(default_args=kwargs) | |||
| regress_tool = MsRegressTool(baseline=False) | |||
| with regress_tool.monitor_ms_train( | |||
| trainer, 'trainer_continue_train', level='strict'): | |||
| trainer.train(os.path.join(self.tmp_dir, 'iter_12.pth')) | |||
| trainer.train(os.path.join(self.tmp_dir, 'iter_3.pth')) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_trainer_with_model_and_args(self): | |||