diff --git a/modelscope/models/nlp/heads/infromation_extraction_head.py b/modelscope/models/nlp/heads/infromation_extraction_head.py index 6c3388f0..626f1b59 100644 --- a/modelscope/models/nlp/heads/infromation_extraction_head.py +++ b/modelscope/models/nlp/heads/infromation_extraction_head.py @@ -10,6 +10,8 @@ from modelscope.utils.constant import Tasks @HEADS.register_module( Tasks.information_extraction, module_name=Heads.information_extraction) +@HEADS.register_module( + Tasks.relation_extraction, module_name=Heads.information_extraction) class InformationExtractionHead(TorchHead): def __init__(self, **kwargs): diff --git a/modelscope/models/nlp/task_models/information_extraction.py b/modelscope/models/nlp/task_models/information_extraction.py index 0a7d5a47..a206c2fc 100644 --- a/modelscope/models/nlp/task_models/information_extraction.py +++ b/modelscope/models/nlp/task_models/information_extraction.py @@ -16,6 +16,8 @@ __all__ = ['InformationExtractionModel'] @MODELS.register_module( Tasks.information_extraction, module_name=TaskModels.information_extraction) +@MODELS.register_module( + Tasks.relation_extraction, module_name=TaskModels.information_extraction) class InformationExtractionModel(SingleBackboneTaskModelBase): def __init__(self, model_dir: str, *args, **kwargs): diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index f183afc1..aaea0bb6 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -31,6 +31,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.named_entity_recognition: (Pipelines.named_entity_recognition, 'damo/nlp_raner_named-entity-recognition_chinese-base-news'), + Tasks.relation_extraction: + (Pipelines.relation_extraction, + 'damo/nlp_bert_relation-extraction_chinese-base'), Tasks.information_extraction: (Pipelines.relation_extraction, 'damo/nlp_bert_relation-extraction_chinese-base'), diff --git a/modelscope/pipelines/nlp/information_extraction_pipeline.py b/modelscope/pipelines/nlp/information_extraction_pipeline.py index 763e941c..8ac85f43 100644 --- a/modelscope/pipelines/nlp/information_extraction_pipeline.py +++ b/modelscope/pipelines/nlp/information_extraction_pipeline.py @@ -17,6 +17,8 @@ __all__ = ['InformationExtractionPipeline'] @PIPELINES.register_module( Tasks.information_extraction, module_name=Pipelines.relation_extraction) +@PIPELINES.register_module( + Tasks.relation_extraction, module_name=Pipelines.relation_extraction) class InformationExtractionPipeline(Pipeline): def __init__(self, diff --git a/tests/pipelines/test_relation_extraction.py b/tests/pipelines/test_relation_extraction.py index 57d98f66..561eaf21 100644 --- a/tests/pipelines/test_relation_extraction.py +++ b/tests/pipelines/test_relation_extraction.py @@ -15,7 +15,7 @@ from modelscope.utils.test_utils import test_level class RelationExtractionTest(unittest.TestCase, DemoCompatibilityCheck): def setUp(self) -> None: - self.task = Tasks.information_extraction + self.task = Tasks.relation_extraction self.model_id = 'damo/nlp_bert_relation-extraction_chinese-base' sentence = '高捷,祖籍江苏,本科毕业于东南大学' @@ -28,7 +28,7 @@ class RelationExtractionTest(unittest.TestCase, DemoCompatibilityCheck): pipeline1 = InformationExtractionPipeline( model, preprocessor=tokenizer) pipeline2 = pipeline( - Tasks.information_extraction, model=model, preprocessor=tokenizer) + Tasks.relation_extraction, model=model, preprocessor=tokenizer) print(f'sentence: {self.sentence}\n' f'pipeline1:{pipeline1(input=self.sentence)}') print() @@ -39,7 +39,7 @@ class RelationExtractionTest(unittest.TestCase, DemoCompatibilityCheck): model = Model.from_pretrained(self.model_id) tokenizer = RelationExtractionPreprocessor(model.model_dir) pipeline_ins = pipeline( - task=Tasks.information_extraction, + task=Tasks.relation_extraction, model=model, preprocessor=tokenizer) print(pipeline_ins(input=self.sentence)) @@ -47,12 +47,12 @@ class RelationExtractionTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_model_name(self): pipeline_ins = pipeline( - task=Tasks.information_extraction, model=self.model_id) + task=Tasks.relation_extraction, model=self.model_id) print(pipeline_ins(input=self.sentence)) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self): - pipeline_ins = pipeline(task=Tasks.information_extraction) + pipeline_ins = pipeline(task=Tasks.relation_extraction) print(pipeline_ins(input=self.sentence)) @unittest.skip('demo compatibility test is only enabled on a needed-basis')