Browse Source

[to #42322933] change star3 to space_T_cn

1. Merge the star and star3 frameworks
2. Change the model type of star and star3
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10492793
master
caorongyu.cry yingda.chen 3 years ago
parent
commit
35644fa0a7
21 changed files with 48 additions and 70 deletions
  1. +3
    -2
      modelscope/metainfo.py
  2. +0
    -0
      modelscope/models/nlp/space_T_cn/__init__.py
  3. +8
    -8
      modelscope/models/nlp/space_T_cn/configuration_space_T_cn.py
  4. +11
    -10
      modelscope/models/nlp/space_T_cn/modeling_space_T_cn.py
  5. +1
    -1
      modelscope/models/nlp/star_text_to_sql.py
  6. +6
    -8
      modelscope/models/nlp/table_question_answering.py
  7. +1
    -10
      modelscope/outputs.py
  8. +0
    -3
      modelscope/pipelines/builder.py
  9. +2
    -2
      modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py
  10. +4
    -3
      modelscope/pipelines/nlp/table_question_answering_pipeline.py
  11. +2
    -2
      modelscope/preprocessors/__init__.py
  12. +0
    -0
      modelscope/preprocessors/space_T_cn/__init__.py
  13. +0
    -0
      modelscope/preprocessors/space_T_cn/fields/__init__.py
  14. +1
    -1
      modelscope/preprocessors/space_T_cn/fields/database.py
  15. +1
    -1
      modelscope/preprocessors/space_T_cn/fields/schema_link.py
  16. +0
    -0
      modelscope/preprocessors/space_T_cn/fields/struct.py
  17. +2
    -2
      modelscope/preprocessors/space_T_cn/table_question_answering_preprocessor.py
  18. +0
    -1
      modelscope/utils/constant.py
  19. +1
    -1
      modelscope/utils/nlp/nlp_utils.py
  20. +1
    -6
      tests/pipelines/test_conversational_text_to_sql.py
  21. +4
    -9
      tests/pipelines/test_table_question_answering.py

+ 3
- 2
modelscope/metainfo.py View File

@@ -67,8 +67,9 @@ class Models(object):
space_dst = 'space-dst' space_dst = 'space-dst'
space_intent = 'space-intent' space_intent = 'space-intent'
space_modeling = 'space-modeling' space_modeling = 'space-modeling'
star = 'star'
star3 = 'star3'
space_T_en = 'space-T-en'
space_T_cn = 'space-T-cn'

tcrf = 'transformer-crf' tcrf = 'transformer-crf'
transformer_softmax = 'transformer-softmax' transformer_softmax = 'transformer-softmax'
lcrf = 'lstm-crf' lcrf = 'lstm-crf'


modelscope/models/nlp/star3/__init__.py → modelscope/models/nlp/space_T_cn/__init__.py View File


modelscope/models/nlp/star3/configuration_star3.py → modelscope/models/nlp/space_T_cn/configuration_space_T_cn.py View File

@@ -24,8 +24,8 @@ import json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)




class Star3Config(object):
"""Configuration class to store the configuration of a `Star3Model`.
class SpaceTCnConfig(object):
"""Configuration class to store the configuration of a `SpaceTCnModel`.
""" """


def __init__(self, def __init__(self,
@@ -40,10 +40,10 @@ class Star3Config(object):
max_position_embeddings=512, max_position_embeddings=512,
type_vocab_size=2, type_vocab_size=2,
initializer_range=0.02): initializer_range=0.02):
"""Constructs Star3Config.
"""Constructs SpaceTCnConfig.


Args: Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `Star3Model`.
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `SpaceTCnModel`.
hidden_size: Size of the encoder layers and the pooler layer. hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder. num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in num_attention_heads: Number of attention heads for each attention layer in
@@ -59,7 +59,7 @@ class Star3Config(object):
max_position_embeddings: The maximum sequence length that this model might max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048). (e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into `Star3Model`.
type_vocab_size: The vocabulary size of the `token_type_ids` passed into `SpaceTCnModel`.
initializer_range: The stddev of the truncated_normal_initializer for initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices. initializing all weight matrices.
""" """
@@ -89,15 +89,15 @@ class Star3Config(object):


@classmethod @classmethod
def from_dict(cls, json_object): def from_dict(cls, json_object):
"""Constructs a `Star3Config` from a Python dictionary of parameters."""
config = Star3Config(vocab_size_or_config_json_file=-1)
"""Constructs a `SpaceTCnConfig` from a Python dictionary of parameters."""
config = SpaceTCnConfig(vocab_size_or_config_json_file=-1)
for key, value in json_object.items(): for key, value in json_object.items():
config.__dict__[key] = value config.__dict__[key] = value
return config return config


@classmethod @classmethod
def from_json_file(cls, json_file): def from_json_file(cls, json_file):
"""Constructs a `Star3Config` from a json file of parameters."""
"""Constructs a `SpaceTCnConfig` from a json file of parameters."""
with open(json_file, 'r', encoding='utf-8') as reader: with open(json_file, 'r', encoding='utf-8') as reader:
text = reader.read() text = reader.read()
return cls.from_dict(json.loads(text)) return cls.from_dict(json.loads(text))

modelscope/models/nlp/star3/modeling_star3.py → modelscope/models/nlp/space_T_cn/modeling_space_T_cn.py View File

@@ -27,7 +27,8 @@ import numpy as np
import torch import torch
from torch import nn from torch import nn


from modelscope.models.nlp.star3.configuration_star3 import Star3Config
from modelscope.models.nlp.space_T_cn.configuration_space_T_cn import \
SpaceTCnConfig
from modelscope.utils.constant import ModelFile from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger


@@ -609,9 +610,9 @@ class PreTrainedBertModel(nn.Module):


def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super(PreTrainedBertModel, self).__init__() super(PreTrainedBertModel, self).__init__()
if not isinstance(config, Star3Config):
if not isinstance(config, SpaceTCnConfig):
raise ValueError( raise ValueError(
'Parameter config in `{}(config)` should be an instance of class `Star3Config`. '
'Parameter config in `{}(config)` should be an instance of class `SpaceTCnConfig`. '
'To create a model from a Google pretrained model use ' 'To create a model from a Google pretrained model use '
'`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`'.format( '`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`'.format(
self.__class__.__name__, self.__class__.__name__)) self.__class__.__name__, self.__class__.__name__))
@@ -676,7 +677,7 @@ class PreTrainedBertModel(nn.Module):
serialization_dir = tempdir serialization_dir = tempdir
# Load config # Load config
config_file = os.path.join(serialization_dir, CONFIG_NAME) config_file = os.path.join(serialization_dir, CONFIG_NAME)
config = Star3Config.from_json_file(config_file)
config = SpaceTCnConfig.from_json_file(config_file)
logger.info('Model config {}'.format(config)) logger.info('Model config {}'.format(config))
# Instantiate model. # Instantiate model.
model = cls(config, *inputs, **kwargs) model = cls(config, *inputs, **kwargs)
@@ -742,11 +743,11 @@ class PreTrainedBertModel(nn.Module):
return model return model




class Star3Model(PreTrainedBertModel):
"""Star3Model model ("Bidirectional Embedding Representations from a Transformer pretrained on STAR3.0").
class SpaceTCnModel(PreTrainedBertModel):
"""SpaceTCnModel model ("Bidirectional Embedding Representations from a Transformer pretrained on STAR-T-CN").


Params: Params:
config: a Star3Config class instance with the configuration to build a new model
config: a SpaceTCnConfig class instance with the configuration to build a new model


Inputs: Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
@@ -780,16 +781,16 @@ class Star3Model(PreTrainedBertModel):
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])


config = modeling.Star3Config(vocab_size_or_config_json_file=32000, hidden_size=768,
config = modeling.SpaceTCnConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)


model = modeling.Star3Model(config=config)
model = modeling.SpaceTCnModel(config=config)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
``` ```
""" """


def __init__(self, config, schema_link_module='none'): def __init__(self, config, schema_link_module='none'):
super(Star3Model, self).__init__(config)
super(SpaceTCnModel, self).__init__(config)
self.embeddings = BertEmbeddings(config) self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder( self.encoder = BertEncoder(
config, schema_link_module=schema_link_module) config, schema_link_module=schema_link_module)

+ 1
- 1
modelscope/models/nlp/star_text_to_sql.py View File

@@ -20,7 +20,7 @@ __all__ = ['StarForTextToSql']




@MODELS.register_module( @MODELS.register_module(
Tasks.conversational_text_to_sql, module_name=Models.star)
Tasks.table_question_answering, module_name=Models.space_T_en)
class StarForTextToSql(Model): class StarForTextToSql(Model):


def __init__(self, model_dir: str, *args, **kwargs): def __init__(self, model_dir: str, *args, **kwargs):


+ 6
- 8
modelscope/models/nlp/table_question_answering.py View File

@@ -3,27 +3,25 @@
import os import os
from typing import Dict from typing import Dict


import json
import numpy import numpy
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import tqdm
from transformers import BertTokenizer from transformers import BertTokenizer


from modelscope.metainfo import Models from modelscope.metainfo import Models
from modelscope.models.base import Model, Tensor from modelscope.models.base import Model, Tensor
from modelscope.models.builder import MODELS from modelscope.models.builder import MODELS
from modelscope.models.nlp.star3.configuration_star3 import Star3Config
from modelscope.models.nlp.star3.modeling_star3 import Seq2SQL, Star3Model
from modelscope.preprocessors.star3.fields.struct import Constant
from modelscope.preprocessors.space_T_cn.fields.struct import Constant
from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.device import verify_device from modelscope.utils.device import verify_device
from .space_T_cn.configuration_space_T_cn import SpaceTCnConfig
from .space_T_cn.modeling_space_T_cn import Seq2SQL, SpaceTCnModel


__all__ = ['TableQuestionAnswering'] __all__ = ['TableQuestionAnswering']




@MODELS.register_module( @MODELS.register_module(
Tasks.table_question_answering, module_name=Models.star3)
Tasks.table_question_answering, module_name=Models.space_T_cn)
class TableQuestionAnswering(Model): class TableQuestionAnswering(Model):


def __init__(self, model_dir: str, *args, **kwargs): def __init__(self, model_dir: str, *args, **kwargs):
@@ -43,9 +41,9 @@ class TableQuestionAnswering(Model):
os.path.join(self.model_dir, ModelFile.TORCH_MODEL_BIN_FILE), os.path.join(self.model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
map_location='cpu') map_location='cpu')


self.backbone_config = Star3Config.from_json_file(
self.backbone_config = SpaceTCnConfig.from_json_file(
os.path.join(self.model_dir, ModelFile.CONFIGURATION)) os.path.join(self.model_dir, ModelFile.CONFIGURATION))
self.backbone_model = Star3Model(
self.backbone_model = SpaceTCnModel(
config=self.backbone_config, schema_link_module='rat') config=self.backbone_config, schema_link_module='rat')
self.backbone_model.load_state_dict(state_dict['backbone_model']) self.backbone_model.load_state_dict(state_dict['backbone_model'])




+ 1
- 10
modelscope/outputs.py View File

@@ -606,21 +606,12 @@ TASK_OUTPUTS = {
# } # }
Tasks.task_oriented_conversation: [OutputKeys.OUTPUT], Tasks.task_oriented_conversation: [OutputKeys.OUTPUT],


# conversational text-to-sql result for single sample
# {
# "text": "SELECT shop.Name FROM shop."
# }
Tasks.conversational_text_to_sql: [OutputKeys.TEXT],

# table-question-answering result for single sample # table-question-answering result for single sample
# { # {
# "sql": "SELECT shop.Name FROM shop." # "sql": "SELECT shop.Name FROM shop."
# "sql_history": {sel: 0, agg: 0, conds: [[0, 0, 'val']]} # "sql_history": {sel: 0, agg: 0, conds: [[0, 0, 'val']]}
# } # }
Tasks.table_question_answering: [
OutputKeys.SQL_STRING, OutputKeys.SQL_QUERY, OutputKeys.HISTORY,
OutputKeys.QUERT_RESULT
],
Tasks.table_question_answering: [OutputKeys.OUTPUT],


# ============ audio tasks =================== # ============ audio tasks ===================
# asr result for single sample # asr result for single sample


+ 0
- 3
modelscope/pipelines/builder.py View File

@@ -69,9 +69,6 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/nlp_space_dialog-modeling'), 'damo/nlp_space_dialog-modeling'),
Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking, Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking,
'damo/nlp_space_dialog-state-tracking'), 'damo/nlp_space_dialog-state-tracking'),
Tasks.conversational_text_to_sql:
(Pipelines.conversational_text_to_sql,
'damo/nlp_star_conversational-text-to-sql'),
Tasks.table_question_answering: Tasks.table_question_answering:
(Pipelines.table_question_answering_pipeline, (Pipelines.table_question_answering_pipeline,
'damo/nlp-convai-text2sql-pretrain-cn'), 'damo/nlp-convai-text2sql-pretrain-cn'),


+ 2
- 2
modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py View File

@@ -19,7 +19,7 @@ __all__ = ['ConversationalTextToSqlPipeline']




@PIPELINES.register_module( @PIPELINES.register_module(
Tasks.conversational_text_to_sql,
Tasks.table_question_answering,
module_name=Pipelines.conversational_text_to_sql) module_name=Pipelines.conversational_text_to_sql)
class ConversationalTextToSqlPipeline(Pipeline): class ConversationalTextToSqlPipeline(Pipeline):


@@ -62,7 +62,7 @@ class ConversationalTextToSqlPipeline(Pipeline):
Dict[str, str]: the prediction results Dict[str, str]: the prediction results
""" """
sql = Example.evaluator.obtain_sql(inputs['predict'][0], inputs['db']) sql = Example.evaluator.obtain_sql(inputs['predict'][0], inputs['db'])
result = {OutputKeys.TEXT: sql}
result = {OutputKeys.OUTPUT: {OutputKeys.TEXT: sql}}
return result return result


def _collate_fn(self, data): def _collate_fn(self, data):


+ 4
- 3
modelscope/pipelines/nlp/table_question_answering_pipeline.py View File

@@ -13,8 +13,9 @@ from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import TableQuestionAnsweringPreprocessor from modelscope.preprocessors import TableQuestionAnsweringPreprocessor
from modelscope.preprocessors.star3.fields.database import Database
from modelscope.preprocessors.star3.fields.struct import Constant, SQLQuery
from modelscope.preprocessors.space_T_cn.fields.database import Database
from modelscope.preprocessors.space_T_cn.fields.struct import (Constant,
SQLQuery)
from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.constant import ModelFile, Tasks


__all__ = ['TableQuestionAnsweringPipeline'] __all__ = ['TableQuestionAnsweringPipeline']
@@ -320,7 +321,7 @@ class TableQuestionAnsweringPipeline(Pipeline):
OutputKeys.QUERT_RESULT: tabledata, OutputKeys.QUERT_RESULT: tabledata,
} }


return output
return {OutputKeys.OUTPUT: output}


def _collate_fn(self, data): def _collate_fn(self, data):
return data return data

+ 2
- 2
modelscope/preprocessors/__init__.py View File

@@ -40,7 +40,7 @@ if TYPE_CHECKING:
DialogStateTrackingPreprocessor) DialogStateTrackingPreprocessor)
from .video import ReadVideoData, MovieSceneSegmentationPreprocessor from .video import ReadVideoData, MovieSceneSegmentationPreprocessor
from .star import ConversationalTextToSqlPreprocessor from .star import ConversationalTextToSqlPreprocessor
from .star3 import TableQuestionAnsweringPreprocessor
from .space_T_cn import TableQuestionAnsweringPreprocessor


else: else:
_import_structure = { _import_structure = {
@@ -81,7 +81,7 @@ else:
'DialogStateTrackingPreprocessor', 'InputFeatures' 'DialogStateTrackingPreprocessor', 'InputFeatures'
], ],
'star': ['ConversationalTextToSqlPreprocessor'], 'star': ['ConversationalTextToSqlPreprocessor'],
'star3': ['TableQuestionAnsweringPreprocessor'],
'space_T_cn': ['TableQuestionAnsweringPreprocessor'],
} }


import sys import sys


modelscope/preprocessors/star3/__init__.py → modelscope/preprocessors/space_T_cn/__init__.py View File


modelscope/preprocessors/star3/fields/__init__.py → modelscope/preprocessors/space_T_cn/fields/__init__.py View File


modelscope/preprocessors/star3/fields/database.py → modelscope/preprocessors/space_T_cn/fields/database.py View File

@@ -4,7 +4,7 @@ import sqlite3
import json import json
import tqdm import tqdm


from modelscope.preprocessors.star3.fields.struct import Trie
from modelscope.preprocessors.space_T_cn.fields.struct import Trie




class Database: class Database:

modelscope/preprocessors/star3/fields/schema_link.py → modelscope/preprocessors/space_T_cn/fields/schema_link.py View File

@@ -1,7 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import re import re


from modelscope.preprocessors.star3.fields.struct import TypeInfo
from modelscope.preprocessors.space_T_cn.fields.struct import TypeInfo




class SchemaLinker: class SchemaLinker:

modelscope/preprocessors/star3/fields/struct.py → modelscope/preprocessors/space_T_cn/fields/struct.py View File


modelscope/preprocessors/star3/table_question_answering_preprocessor.py → modelscope/preprocessors/space_T_cn/table_question_answering_preprocessor.py View File

@@ -8,8 +8,8 @@ from transformers import BertTokenizer
from modelscope.metainfo import Preprocessors from modelscope.metainfo import Preprocessors
from modelscope.preprocessors.base import Preprocessor from modelscope.preprocessors.base import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.preprocessors.star3.fields.database import Database
from modelscope.preprocessors.star3.fields.schema_link import SchemaLinker
from modelscope.preprocessors.space_T_cn.fields.database import Database
from modelscope.preprocessors.space_T_cn.fields.schema_link import SchemaLinker
from modelscope.utils.config import Config from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModelFile from modelscope.utils.constant import Fields, ModelFile
from modelscope.utils.type_assert import type_assert from modelscope.utils.type_assert import type_assert

+ 0
- 1
modelscope/utils/constant.py View File

@@ -123,7 +123,6 @@ class NLPTasks(object):
backbone = 'backbone' backbone = 'backbone'
text_error_correction = 'text-error-correction' text_error_correction = 'text-error-correction'
faq_question_answering = 'faq-question-answering' faq_question_answering = 'faq-question-answering'
conversational_text_to_sql = 'conversational-text-to-sql'
information_extraction = 'information-extraction' information_extraction = 'information-extraction'
document_segmentation = 'document-segmentation' document_segmentation = 'document-segmentation'
feature_extraction = 'feature-extraction' feature_extraction = 'feature-extraction'


+ 1
- 1
modelscope/utils/nlp/nlp_utils.py View File

@@ -20,7 +20,7 @@ def text2sql_tracking_and_print_results(
results = p(case) results = p(case)
print({'question': item}) print({'question': item})
print(results) print(results)
last_sql = results['text']
last_sql = results[OutputKeys.OUTPUT][OutputKeys.TEXT]
history.append(item) history.append(item)






+ 1
- 6
tests/pipelines/test_conversational_text_to_sql.py View File

@@ -16,7 +16,7 @@ from modelscope.utils.test_utils import test_level
class ConversationalTextToSql(unittest.TestCase, DemoCompatibilityCheck): class ConversationalTextToSql(unittest.TestCase, DemoCompatibilityCheck):


def setUp(self) -> None: def setUp(self) -> None:
self.task = Tasks.conversational_text_to_sql
self.task = Tasks.table_question_answering
self.model_id = 'damo/nlp_star_conversational-text-to-sql' self.model_id = 'damo/nlp_star_conversational-text-to-sql'


model_id = 'damo/nlp_star_conversational-text-to-sql' model_id = 'damo/nlp_star_conversational-text-to-sql'
@@ -66,11 +66,6 @@ class ConversationalTextToSql(unittest.TestCase, DemoCompatibilityCheck):
pipelines = [pipeline(task=self.task, model=self.model_id)] pipelines = [pipeline(task=self.task, model=self.model_id)]
text2sql_tracking_and_print_results(self.test_case, pipelines) text2sql_tracking_and_print_results(self.test_case, pipelines)


@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipelines = [pipeline(task=self.task)]
text2sql_tracking_and_print_results(self.test_case, pipelines)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_demo_compatibility(self): def test_demo_compatibility(self):
self.compatibility_check() self.compatibility_check()


+ 4
- 9
tests/pipelines/test_table_question_answering.py View File

@@ -12,7 +12,7 @@ from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline
from modelscope.preprocessors import TableQuestionAnsweringPreprocessor from modelscope.preprocessors import TableQuestionAnsweringPreprocessor
from modelscope.preprocessors.star3.fields.database import Database
from modelscope.preprocessors.space_T_cn.fields.database import Database
from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level from modelscope.utils.test_utils import test_level


@@ -38,7 +38,7 @@ def tableqa_tracking_and_print_results_with_history(
output_dict = p({ output_dict = p({
'question': question, 'question': question,
'history_sql': historical_queries 'history_sql': historical_queries
})
})[OutputKeys.OUTPUT]
print('question', question) print('question', question)
print('sql text:', output_dict[OutputKeys.SQL_STRING]) print('sql text:', output_dict[OutputKeys.SQL_STRING])
print('sql query:', output_dict[OutputKeys.SQL_QUERY]) print('sql query:', output_dict[OutputKeys.SQL_QUERY])
@@ -61,7 +61,7 @@ def tableqa_tracking_and_print_results_without_history(
} }
for p in pipelines: for p in pipelines:
for question in test_case['utterance']: for question in test_case['utterance']:
output_dict = p({'question': question})
output_dict = p({'question': question})[OutputKeys.OUTPUT]
print('question', question) print('question', question)
print('sql text:', output_dict[OutputKeys.SQL_STRING]) print('sql text:', output_dict[OutputKeys.SQL_STRING])
print('sql query:', output_dict[OutputKeys.SQL_QUERY]) print('sql query:', output_dict[OutputKeys.SQL_QUERY])
@@ -92,7 +92,7 @@ def tableqa_tracking_and_print_results_with_tableid(
'question': question, 'question': question,
'table_id': table_id, 'table_id': table_id,
'history_sql': historical_queries 'history_sql': historical_queries
})
})[OutputKeys.OUTPUT]
print('question', question) print('question', question)
print('sql text:', output_dict[OutputKeys.SQL_STRING]) print('sql text:', output_dict[OutputKeys.SQL_STRING])
print('sql query:', output_dict[OutputKeys.SQL_QUERY]) print('sql query:', output_dict[OutputKeys.SQL_QUERY])
@@ -147,11 +147,6 @@ class TableQuestionAnswering(unittest.TestCase):
] ]
tableqa_tracking_and_print_results_with_tableid(pipelines) tableqa_tracking_and_print_results_with_tableid(pipelines)


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_task(self):
pipelines = [pipeline(Tasks.table_question_answering, self.model_id)]
tableqa_tracking_and_print_results_with_history(pipelines)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_model_from_modelhub_with_other_classes(self): def test_run_with_model_from_modelhub_with_other_classes(self):
model = Model.from_pretrained(self.model_id) model = Model.from_pretrained(self.model_id)


Loading…
Cancel
Save