1. 合并star和star3框架
2. 修改star和star3的model type
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10492793
master
| @@ -67,8 +67,9 @@ class Models(object): | |||||
| space_dst = 'space-dst' | space_dst = 'space-dst' | ||||
| space_intent = 'space-intent' | space_intent = 'space-intent' | ||||
| space_modeling = 'space-modeling' | space_modeling = 'space-modeling' | ||||
| star = 'star' | |||||
| star3 = 'star3' | |||||
| space_T_en = 'space-T-en' | |||||
| space_T_cn = 'space-T-cn' | |||||
| tcrf = 'transformer-crf' | tcrf = 'transformer-crf' | ||||
| transformer_softmax = 'transformer-softmax' | transformer_softmax = 'transformer-softmax' | ||||
| lcrf = 'lstm-crf' | lcrf = 'lstm-crf' | ||||
| @@ -24,8 +24,8 @@ import json | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| class Star3Config(object): | |||||
| """Configuration class to store the configuration of a `Star3Model`. | |||||
| class SpaceTCnConfig(object): | |||||
| """Configuration class to store the configuration of a `SpaceTCnModel`. | |||||
| """ | """ | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -40,10 +40,10 @@ class Star3Config(object): | |||||
| max_position_embeddings=512, | max_position_embeddings=512, | ||||
| type_vocab_size=2, | type_vocab_size=2, | ||||
| initializer_range=0.02): | initializer_range=0.02): | ||||
| """Constructs Star3Config. | |||||
| """Constructs SpaceTCnConfig. | |||||
| Args: | Args: | ||||
| vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `Star3Model`. | |||||
| vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `SpaceTCnConfig`. | |||||
| hidden_size: Size of the encoder layers and the pooler layer. | hidden_size: Size of the encoder layers and the pooler layer. | ||||
| num_hidden_layers: Number of hidden layers in the Transformer encoder. | num_hidden_layers: Number of hidden layers in the Transformer encoder. | ||||
| num_attention_heads: Number of attention heads for each attention layer in | num_attention_heads: Number of attention heads for each attention layer in | ||||
| @@ -59,7 +59,7 @@ class Star3Config(object): | |||||
| max_position_embeddings: The maximum sequence length that this model might | max_position_embeddings: The maximum sequence length that this model might | ||||
| ever be used with. Typically set this to something large just in case | ever be used with. Typically set this to something large just in case | ||||
| (e.g., 512 or 1024 or 2048). | (e.g., 512 or 1024 or 2048). | ||||
| type_vocab_size: The vocabulary size of the `token_type_ids` passed into `Star3Model`. | |||||
| type_vocab_size: The vocabulary size of the `token_type_ids` passed into `SpaceTCnConfig`. | |||||
| initializer_range: The sttdev of the truncated_normal_initializer for | initializer_range: The sttdev of the truncated_normal_initializer for | ||||
| initializing all weight matrices. | initializing all weight matrices. | ||||
| """ | """ | ||||
| @@ -89,15 +89,15 @@ class Star3Config(object): | |||||
| @classmethod | @classmethod | ||||
| def from_dict(cls, json_object): | def from_dict(cls, json_object): | ||||
| """Constructs a `Star3Config` from a Python dictionary of parameters.""" | |||||
| config = Star3Config(vocab_size_or_config_json_file=-1) | |||||
| """Constructs a `SpaceTCnConfig` from a Python dictionary of parameters.""" | |||||
| config = SpaceTCnConfig(vocab_size_or_config_json_file=-1) | |||||
| for key, value in json_object.items(): | for key, value in json_object.items(): | ||||
| config.__dict__[key] = value | config.__dict__[key] = value | ||||
| return config | return config | ||||
| @classmethod | @classmethod | ||||
| def from_json_file(cls, json_file): | def from_json_file(cls, json_file): | ||||
| """Constructs a `Star3Config` from a json file of parameters.""" | |||||
| """Constructs a `SpaceTCnConfig` from a json file of parameters.""" | |||||
| with open(json_file, 'r', encoding='utf-8') as reader: | with open(json_file, 'r', encoding='utf-8') as reader: | ||||
| text = reader.read() | text = reader.read() | ||||
| return cls.from_dict(json.loads(text)) | return cls.from_dict(json.loads(text)) | ||||
| @@ -27,7 +27,8 @@ import numpy as np | |||||
| import torch | import torch | ||||
| from torch import nn | from torch import nn | ||||
| from modelscope.models.nlp.star3.configuration_star3 import Star3Config | |||||
| from modelscope.models.nlp.space_T_cn.configuration_space_T_cn import \ | |||||
| SpaceTCnConfig | |||||
| from modelscope.utils.constant import ModelFile | from modelscope.utils.constant import ModelFile | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| @@ -609,9 +610,9 @@ class PreTrainedBertModel(nn.Module): | |||||
| def __init__(self, config, *inputs, **kwargs): | def __init__(self, config, *inputs, **kwargs): | ||||
| super(PreTrainedBertModel, self).__init__() | super(PreTrainedBertModel, self).__init__() | ||||
| if not isinstance(config, Star3Config): | |||||
| if not isinstance(config, SpaceTCnConfig): | |||||
| raise ValueError( | raise ValueError( | ||||
| 'Parameter config in `{}(config)` should be an instance of class `Star3Config`. ' | |||||
| 'Parameter config in `{}(config)` should be an instance of class `SpaceTCnConfig`. ' | |||||
| 'To create a model from a Google pretrained model use ' | 'To create a model from a Google pretrained model use ' | ||||
| '`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`'.format( | '`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`'.format( | ||||
| self.__class__.__name__, self.__class__.__name__)) | self.__class__.__name__, self.__class__.__name__)) | ||||
| @@ -676,7 +677,7 @@ class PreTrainedBertModel(nn.Module): | |||||
| serialization_dir = tempdir | serialization_dir = tempdir | ||||
| # Load config | # Load config | ||||
| config_file = os.path.join(serialization_dir, CONFIG_NAME) | config_file = os.path.join(serialization_dir, CONFIG_NAME) | ||||
| config = Star3Config.from_json_file(config_file) | |||||
| config = SpaceTCnConfig.from_json_file(config_file) | |||||
| logger.info('Model config {}'.format(config)) | logger.info('Model config {}'.format(config)) | ||||
| # Instantiate model. | # Instantiate model. | ||||
| model = cls(config, *inputs, **kwargs) | model = cls(config, *inputs, **kwargs) | ||||
| @@ -742,11 +743,11 @@ class PreTrainedBertModel(nn.Module): | |||||
| return model | return model | ||||
| class Star3Model(PreTrainedBertModel): | |||||
| """Star3Model model ("Bidirectional Embedding Representations from a Transformer pretrained on STAR3.0"). | |||||
| class SpaceTCnModel(PreTrainedBertModel): | |||||
| """SpaceTCnModel model ("Bidirectional Embedding Representations from a Transformer pretrained on STAR-T-CN"). | |||||
| Params: | Params: | ||||
| config: a Star3Config class instance with the configuration to build a new model | |||||
| config: a SpaceTCnConfig class instance with the configuration to build a new model | |||||
| Inputs: | Inputs: | ||||
| `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] | ||||
| @@ -780,16 +781,16 @@ class Star3Model(PreTrainedBertModel): | |||||
| input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | ||||
| token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) | ||||
| config = modeling.Star3Config(vocab_size_or_config_json_file=32000, hidden_size=768, | |||||
| config = modeling.SpaceTCnConfig(vocab_size_or_config_json_file=32000, hidden_size=768, | |||||
| num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | ||||
| model = modeling.Star3Model(config=config) | |||||
| model = modeling.SpaceTCnModel(config=config) | |||||
| all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) | ||||
| ``` | ``` | ||||
| """ | """ | ||||
| def __init__(self, config, schema_link_module='none'): | def __init__(self, config, schema_link_module='none'): | ||||
| super(Star3Model, self).__init__(config) | |||||
| super(SpaceTCnModel, self).__init__(config) | |||||
| self.embeddings = BertEmbeddings(config) | self.embeddings = BertEmbeddings(config) | ||||
| self.encoder = BertEncoder( | self.encoder = BertEncoder( | ||||
| config, schema_link_module=schema_link_module) | config, schema_link_module=schema_link_module) | ||||
| @@ -20,7 +20,7 @@ __all__ = ['StarForTextToSql'] | |||||
| @MODELS.register_module( | @MODELS.register_module( | ||||
| Tasks.conversational_text_to_sql, module_name=Models.star) | |||||
| Tasks.table_question_answering, module_name=Models.space_T_en) | |||||
| class StarForTextToSql(Model): | class StarForTextToSql(Model): | ||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| @@ -3,27 +3,25 @@ | |||||
| import os | import os | ||||
| from typing import Dict | from typing import Dict | ||||
| import json | |||||
| import numpy | import numpy | ||||
| import torch | import torch | ||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||
| import tqdm | |||||
| from transformers import BertTokenizer | from transformers import BertTokenizer | ||||
| from modelscope.metainfo import Models | from modelscope.metainfo import Models | ||||
| from modelscope.models.base import Model, Tensor | from modelscope.models.base import Model, Tensor | ||||
| from modelscope.models.builder import MODELS | from modelscope.models.builder import MODELS | ||||
| from modelscope.models.nlp.star3.configuration_star3 import Star3Config | |||||
| from modelscope.models.nlp.star3.modeling_star3 import Seq2SQL, Star3Model | |||||
| from modelscope.preprocessors.star3.fields.struct import Constant | |||||
| from modelscope.preprocessors.space_T_cn.fields.struct import Constant | |||||
| from modelscope.utils.constant import ModelFile, Tasks | from modelscope.utils.constant import ModelFile, Tasks | ||||
| from modelscope.utils.device import verify_device | from modelscope.utils.device import verify_device | ||||
| from .space_T_cn.configuration_space_T_cn import SpaceTCnConfig | |||||
| from .space_T_cn.modeling_space_T_cn import Seq2SQL, SpaceTCnModel | |||||
| __all__ = ['TableQuestionAnswering'] | __all__ = ['TableQuestionAnswering'] | ||||
| @MODELS.register_module( | @MODELS.register_module( | ||||
| Tasks.table_question_answering, module_name=Models.star3) | |||||
| Tasks.table_question_answering, module_name=Models.space_T_cn) | |||||
| class TableQuestionAnswering(Model): | class TableQuestionAnswering(Model): | ||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| @@ -43,9 +41,9 @@ class TableQuestionAnswering(Model): | |||||
| os.path.join(self.model_dir, ModelFile.TORCH_MODEL_BIN_FILE), | os.path.join(self.model_dir, ModelFile.TORCH_MODEL_BIN_FILE), | ||||
| map_location='cpu') | map_location='cpu') | ||||
| self.backbone_config = Star3Config.from_json_file( | |||||
| self.backbone_config = SpaceTCnConfig.from_json_file( | |||||
| os.path.join(self.model_dir, ModelFile.CONFIGURATION)) | os.path.join(self.model_dir, ModelFile.CONFIGURATION)) | ||||
| self.backbone_model = Star3Model( | |||||
| self.backbone_model = SpaceTCnModel( | |||||
| config=self.backbone_config, schema_link_module='rat') | config=self.backbone_config, schema_link_module='rat') | ||||
| self.backbone_model.load_state_dict(state_dict['backbone_model']) | self.backbone_model.load_state_dict(state_dict['backbone_model']) | ||||
| @@ -606,21 +606,12 @@ TASK_OUTPUTS = { | |||||
| # } | # } | ||||
| Tasks.task_oriented_conversation: [OutputKeys.OUTPUT], | Tasks.task_oriented_conversation: [OutputKeys.OUTPUT], | ||||
| # conversational text-to-sql result for single sample | |||||
| # { | |||||
| # "text": "SELECT shop.Name FROM shop." | |||||
| # } | |||||
| Tasks.conversational_text_to_sql: [OutputKeys.TEXT], | |||||
| # table-question-answering result for single sample | # table-question-answering result for single sample | ||||
| # { | # { | ||||
| # "sql": "SELECT shop.Name FROM shop." | # "sql": "SELECT shop.Name FROM shop." | ||||
| # "sql_history": {sel: 0, agg: 0, conds: [[0, 0, 'val']]} | # "sql_history": {sel: 0, agg: 0, conds: [[0, 0, 'val']]} | ||||
| # } | # } | ||||
| Tasks.table_question_answering: [ | |||||
| OutputKeys.SQL_STRING, OutputKeys.SQL_QUERY, OutputKeys.HISTORY, | |||||
| OutputKeys.QUERT_RESULT | |||||
| ], | |||||
| Tasks.table_question_answering: [OutputKeys.OUTPUT], | |||||
| # ============ audio tasks =================== | # ============ audio tasks =================== | ||||
| # asr result for single sample | # asr result for single sample | ||||
| @@ -69,9 +69,6 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| 'damo/nlp_space_dialog-modeling'), | 'damo/nlp_space_dialog-modeling'), | ||||
| Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking, | Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking, | ||||
| 'damo/nlp_space_dialog-state-tracking'), | 'damo/nlp_space_dialog-state-tracking'), | ||||
| Tasks.conversational_text_to_sql: | |||||
| (Pipelines.conversational_text_to_sql, | |||||
| 'damo/nlp_star_conversational-text-to-sql'), | |||||
| Tasks.table_question_answering: | Tasks.table_question_answering: | ||||
| (Pipelines.table_question_answering_pipeline, | (Pipelines.table_question_answering_pipeline, | ||||
| 'damo/nlp-convai-text2sql-pretrain-cn'), | 'damo/nlp-convai-text2sql-pretrain-cn'), | ||||
| @@ -19,7 +19,7 @@ __all__ = ['ConversationalTextToSqlPipeline'] | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.conversational_text_to_sql, | |||||
| Tasks.table_question_answering, | |||||
| module_name=Pipelines.conversational_text_to_sql) | module_name=Pipelines.conversational_text_to_sql) | ||||
| class ConversationalTextToSqlPipeline(Pipeline): | class ConversationalTextToSqlPipeline(Pipeline): | ||||
| @@ -62,7 +62,7 @@ class ConversationalTextToSqlPipeline(Pipeline): | |||||
| Dict[str, str]: the prediction results | Dict[str, str]: the prediction results | ||||
| """ | """ | ||||
| sql = Example.evaluator.obtain_sql(inputs['predict'][0], inputs['db']) | sql = Example.evaluator.obtain_sql(inputs['predict'][0], inputs['db']) | ||||
| result = {OutputKeys.TEXT: sql} | |||||
| result = {OutputKeys.OUTPUT: {OutputKeys.TEXT: sql}} | |||||
| return result | return result | ||||
| def _collate_fn(self, data): | def _collate_fn(self, data): | ||||
| @@ -13,8 +13,9 @@ from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Pipeline | from modelscope.pipelines.base import Pipeline | ||||
| from modelscope.pipelines.builder import PIPELINES | from modelscope.pipelines.builder import PIPELINES | ||||
| from modelscope.preprocessors import TableQuestionAnsweringPreprocessor | from modelscope.preprocessors import TableQuestionAnsweringPreprocessor | ||||
| from modelscope.preprocessors.star3.fields.database import Database | |||||
| from modelscope.preprocessors.star3.fields.struct import Constant, SQLQuery | |||||
| from modelscope.preprocessors.space_T_cn.fields.database import Database | |||||
| from modelscope.preprocessors.space_T_cn.fields.struct import (Constant, | |||||
| SQLQuery) | |||||
| from modelscope.utils.constant import ModelFile, Tasks | from modelscope.utils.constant import ModelFile, Tasks | ||||
| __all__ = ['TableQuestionAnsweringPipeline'] | __all__ = ['TableQuestionAnsweringPipeline'] | ||||
| @@ -320,7 +321,7 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||||
| OutputKeys.QUERT_RESULT: tabledata, | OutputKeys.QUERT_RESULT: tabledata, | ||||
| } | } | ||||
| return output | |||||
| return {OutputKeys.OUTPUT: output} | |||||
| def _collate_fn(self, data): | def _collate_fn(self, data): | ||||
| return data | return data | ||||
| @@ -40,7 +40,7 @@ if TYPE_CHECKING: | |||||
| DialogStateTrackingPreprocessor) | DialogStateTrackingPreprocessor) | ||||
| from .video import ReadVideoData, MovieSceneSegmentationPreprocessor | from .video import ReadVideoData, MovieSceneSegmentationPreprocessor | ||||
| from .star import ConversationalTextToSqlPreprocessor | from .star import ConversationalTextToSqlPreprocessor | ||||
| from .star3 import TableQuestionAnsweringPreprocessor | |||||
| from .space_T_cn import TableQuestionAnsweringPreprocessor | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| @@ -81,7 +81,7 @@ else: | |||||
| 'DialogStateTrackingPreprocessor', 'InputFeatures' | 'DialogStateTrackingPreprocessor', 'InputFeatures' | ||||
| ], | ], | ||||
| 'star': ['ConversationalTextToSqlPreprocessor'], | 'star': ['ConversationalTextToSqlPreprocessor'], | ||||
| 'star3': ['TableQuestionAnsweringPreprocessor'], | |||||
| 'space_T_cn': ['TableQuestionAnsweringPreprocessor'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -4,7 +4,7 @@ import sqlite3 | |||||
| import json | import json | ||||
| import tqdm | import tqdm | ||||
| from modelscope.preprocessors.star3.fields.struct import Trie | |||||
| from modelscope.preprocessors.space_T_cn.fields.struct import Trie | |||||
| class Database: | class Database: | ||||
| @@ -1,7 +1,7 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import re | import re | ||||
| from modelscope.preprocessors.star3.fields.struct import TypeInfo | |||||
| from modelscope.preprocessors.space_T_cn.fields.struct import TypeInfo | |||||
| class SchemaLinker: | class SchemaLinker: | ||||
| @@ -8,8 +8,8 @@ from transformers import BertTokenizer | |||||
| from modelscope.metainfo import Preprocessors | from modelscope.metainfo import Preprocessors | ||||
| from modelscope.preprocessors.base import Preprocessor | from modelscope.preprocessors.base import Preprocessor | ||||
| from modelscope.preprocessors.builder import PREPROCESSORS | from modelscope.preprocessors.builder import PREPROCESSORS | ||||
| from modelscope.preprocessors.star3.fields.database import Database | |||||
| from modelscope.preprocessors.star3.fields.schema_link import SchemaLinker | |||||
| from modelscope.preprocessors.space_T_cn.fields.database import Database | |||||
| from modelscope.preprocessors.space_T_cn.fields.schema_link import SchemaLinker | |||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import Fields, ModelFile | from modelscope.utils.constant import Fields, ModelFile | ||||
| from modelscope.utils.type_assert import type_assert | from modelscope.utils.type_assert import type_assert | ||||
| @@ -123,7 +123,6 @@ class NLPTasks(object): | |||||
| backbone = 'backbone' | backbone = 'backbone' | ||||
| text_error_correction = 'text-error-correction' | text_error_correction = 'text-error-correction' | ||||
| faq_question_answering = 'faq-question-answering' | faq_question_answering = 'faq-question-answering' | ||||
| conversational_text_to_sql = 'conversational-text-to-sql' | |||||
| information_extraction = 'information-extraction' | information_extraction = 'information-extraction' | ||||
| document_segmentation = 'document-segmentation' | document_segmentation = 'document-segmentation' | ||||
| feature_extraction = 'feature-extraction' | feature_extraction = 'feature-extraction' | ||||
| @@ -20,7 +20,7 @@ def text2sql_tracking_and_print_results( | |||||
| results = p(case) | results = p(case) | ||||
| print({'question': item}) | print({'question': item}) | ||||
| print(results) | print(results) | ||||
| last_sql = results['text'] | |||||
| last_sql = results[OutputKeys.OUTPUT][OutputKeys.TEXT] | |||||
| history.append(item) | history.append(item) | ||||
| @@ -16,7 +16,7 @@ from modelscope.utils.test_utils import test_level | |||||
| class ConversationalTextToSql(unittest.TestCase, DemoCompatibilityCheck): | class ConversationalTextToSql(unittest.TestCase, DemoCompatibilityCheck): | ||||
| def setUp(self) -> None: | def setUp(self) -> None: | ||||
| self.task = Tasks.conversational_text_to_sql | |||||
| self.task = Tasks.table_question_answering | |||||
| self.model_id = 'damo/nlp_star_conversational-text-to-sql' | self.model_id = 'damo/nlp_star_conversational-text-to-sql' | ||||
| model_id = 'damo/nlp_star_conversational-text-to-sql' | model_id = 'damo/nlp_star_conversational-text-to-sql' | ||||
| @@ -66,11 +66,6 @@ class ConversationalTextToSql(unittest.TestCase, DemoCompatibilityCheck): | |||||
| pipelines = [pipeline(task=self.task, model=self.model_id)] | pipelines = [pipeline(task=self.task, model=self.model_id)] | ||||
| text2sql_tracking_and_print_results(self.test_case, pipelines) | text2sql_tracking_and_print_results(self.test_case, pipelines) | ||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_with_default_model(self): | |||||
| pipelines = [pipeline(task=self.task)] | |||||
| text2sql_tracking_and_print_results(self.test_case, pipelines) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_demo_compatibility(self): | def test_demo_compatibility(self): | ||||
| self.compatibility_check() | self.compatibility_check() | ||||
| @@ -12,7 +12,7 @@ from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines import pipeline | from modelscope.pipelines import pipeline | ||||
| from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline | from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline | ||||
| from modelscope.preprocessors import TableQuestionAnsweringPreprocessor | from modelscope.preprocessors import TableQuestionAnsweringPreprocessor | ||||
| from modelscope.preprocessors.star3.fields.database import Database | |||||
| from modelscope.preprocessors.space_T_cn.fields.database import Database | |||||
| from modelscope.utils.constant import ModelFile, Tasks | from modelscope.utils.constant import ModelFile, Tasks | ||||
| from modelscope.utils.test_utils import test_level | from modelscope.utils.test_utils import test_level | ||||
| @@ -38,7 +38,7 @@ def tableqa_tracking_and_print_results_with_history( | |||||
| output_dict = p({ | output_dict = p({ | ||||
| 'question': question, | 'question': question, | ||||
| 'history_sql': historical_queries | 'history_sql': historical_queries | ||||
| }) | |||||
| })[OutputKeys.OUTPUT] | |||||
| print('question', question) | print('question', question) | ||||
| print('sql text:', output_dict[OutputKeys.SQL_STRING]) | print('sql text:', output_dict[OutputKeys.SQL_STRING]) | ||||
| print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | ||||
| @@ -61,7 +61,7 @@ def tableqa_tracking_and_print_results_without_history( | |||||
| } | } | ||||
| for p in pipelines: | for p in pipelines: | ||||
| for question in test_case['utterance']: | for question in test_case['utterance']: | ||||
| output_dict = p({'question': question}) | |||||
| output_dict = p({'question': question})[OutputKeys.OUTPUT] | |||||
| print('question', question) | print('question', question) | ||||
| print('sql text:', output_dict[OutputKeys.SQL_STRING]) | print('sql text:', output_dict[OutputKeys.SQL_STRING]) | ||||
| print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | ||||
| @@ -92,7 +92,7 @@ def tableqa_tracking_and_print_results_with_tableid( | |||||
| 'question': question, | 'question': question, | ||||
| 'table_id': table_id, | 'table_id': table_id, | ||||
| 'history_sql': historical_queries | 'history_sql': historical_queries | ||||
| }) | |||||
| })[OutputKeys.OUTPUT] | |||||
| print('question', question) | print('question', question) | ||||
| print('sql text:', output_dict[OutputKeys.SQL_STRING]) | print('sql text:', output_dict[OutputKeys.SQL_STRING]) | ||||
| print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | ||||
| @@ -147,11 +147,6 @@ class TableQuestionAnswering(unittest.TestCase): | |||||
| ] | ] | ||||
| tableqa_tracking_and_print_results_with_tableid(pipelines) | tableqa_tracking_and_print_results_with_tableid(pipelines) | ||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_model_from_task(self): | |||||
| pipelines = [pipeline(Tasks.table_question_answering, self.model_id)] | |||||
| tableqa_tracking_and_print_results_with_history(pipelines) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | ||||
| def test_run_with_model_from_modelhub_with_other_classes(self): | def test_run_with_model_from_modelhub_with_other_classes(self): | ||||
| model = Model.from_pretrained(self.model_id) | model = Model.from_pretrained(self.model_id) | ||||