[to #42322933] add conversational_text_to_sql pipeline

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9580066
master · piaoyu.lxy, yingda.chen · 3 years ago
commit 2dc3286524
19 changed files with 1308 additions and 0 deletions
  1. modelscope/metainfo.py (+3, -0)
  2. modelscope/models/nlp/__init__.py (+2, -0)
  3. modelscope/models/nlp/star_text_to_sql.py (+68, -0)
  4. modelscope/outputs.py (+6, -0)
  5. modelscope/pipelines/base.py (+3, -0)
  6. modelscope/pipelines/builder.py (+5, -0)
  7. modelscope/pipelines/nlp/__init__.py (+3, -0)
  8. modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py (+66, -0)
  9. modelscope/preprocessors/__init__.py (+2, -0)
  10. modelscope/preprocessors/star/__init__.py (+29, -0)
  11. modelscope/preprocessors/star/conversational_text_to_sql_preprocessor.py (+111, -0)
  12. modelscope/preprocessors/star/fields/__init__.py (+6, -0)
  13. modelscope/preprocessors/star/fields/common_utils.py (+471, -0)
  14. modelscope/preprocessors/star/fields/parse.py (+333, -0)
  15. modelscope/preprocessors/star/fields/preprocess_dataset.py (+37, -0)
  16. modelscope/preprocessors/star/fields/process_dataset.py (+64, -0)
  17. modelscope/utils/constant.py (+1, -0)
  18. requirements/nlp.txt (+1, -0)
  19. tests/pipelines/test_conversational_text_to_sql.py (+97, -0)
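
For orientation, a minimal usage sketch of the new pipeline, assembled from the test case and the preprocessor's expected input further down (field values are illustrative; the model id comes from the default-model mapping added in builder.py):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    text2sql = pipeline(
        task=Tasks.conversational_text_to_sql,
        model='damo/nlp_star_conversational-text-to-sql')
    result = text2sql({
        'utterance': 'Which of these are hiring?',    # current user turn
        'history': ["I'd like to see Shop names."],   # previous turns of the dialogue
        'last_sql': '',                               # SQL predicted for the previous turn
        'database_id': 'employee_hire_evaluation',    # db_id from tables.json
        'local_db_path': None                         # or a folder holding tables.json and db/
    })
    print(result['text'])  # e.g. "SELECT shop.Name FROM shop"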

modelscope/metainfo.py (+3, -0)

@@ -29,6 +29,7 @@ class Models(object):
space_dst = 'space-dst'
space_intent = 'space-intent'
space_modeling = 'space-modeling'
star = 'star'
tcrf = 'transformer-crf'
bart = 'bart'
gpt3 = 'gpt3'
@@ -123,6 +124,7 @@ class Pipelines(object):
dialog_state_tracking = 'dialog-state-tracking'
zero_shot_classification = 'zero-shot-classification'
text_error_correction = 'text-error-correction'
conversational_text_to_sql = 'conversational-text-to-sql'

# audio tasks
sambert_hifigan_tts = 'sambert-hifigan-tts'
@@ -201,6 +203,7 @@ class Preprocessors(object):
text_error_correction = 'text-error-correction'
word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor'
fill_mask = 'fill-mask'
conversational_text_to_sql = 'conversational-text-to-sql'

# audio preprocessor
linear_aec_fbank = 'linear-aec-fbank'


modelscope/models/nlp/__init__.py (+2, -0)

@@ -17,12 +17,14 @@ if TYPE_CHECKING:
from .space import SpaceForDialogIntent
from .space import SpaceForDialogModeling
from .space import SpaceForDialogStateTracking
from .star_text_to_sql import StarForTextToSql
from .task_models.task_model import SingleBackboneTaskModelBase
from .bart_for_text_error_correction import BartForTextErrorCorrection
from .gpt3 import GPT3ForTextGeneration

else:
_import_structure = {
'star_text_to_sql': ['StarForTextToSql'],
'backbones': ['SbertModel'],
'heads': ['SequenceClassificationHead'],
'csanmt_for_translation': ['CsanmtForTranslation'],


modelscope/models/nlp/star_text_to_sql.py (+68, -0)

@@ -0,0 +1,68 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Dict, Optional

import torch
import torch.nn as nn
from text2sql_lgesql.asdl.asdl import ASDLGrammar
from text2sql_lgesql.asdl.transition_system import TransitionSystem
from text2sql_lgesql.model.model_constructor import Text2SQL
from text2sql_lgesql.utils.constants import GRAMMAR_FILEPATH

from modelscope.metainfo import Models
from modelscope.models.base import Model, Tensor
from modelscope.models.builder import MODELS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks

__all__ = ['StarForTextToSql']


@MODELS.register_module(
Tasks.conversational_text_to_sql, module_name=Models.star)
class StarForTextToSql(Model):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the star model from the `model_dir` path.

Args:
model_dir (str): the model path.
"""
super().__init__(model_dir, *args, **kwargs)
self.beam_size = 5
self.config = kwargs.pop(
'config',
Config.from_file(
os.path.join(self.model_dir, ModelFile.CONFIGURATION)))
self.config.model.model_dir = model_dir
self.grammar = ASDLGrammar.from_filepath(
os.path.join(model_dir, 'sql_asdl_v2.txt'))
self.trans = TransitionSystem.get_class_by_lang('sql')(self.grammar)
self.arg = self.config.model
self.device = 'cuda' if \
('device' not in kwargs or kwargs['device'] == 'gpu') \
and torch.cuda.is_available() else 'cpu'
self.model = Text2SQL(self.arg, self.trans)
check_point = torch.load(
open(
os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'rb'),
map_location=self.device)
self.model.load_state_dict(check_point['model'])

def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""return the result by the model

Args:
input (Dict[str, Tensor]): the preprocessed data

Returns:
Dict[str, Tensor]: results
Example:
"""
self.model.eval()
hyps = self.model.parse(input['batch'], self.beam_size)  # beam-search decode
db = input['batch'].examples[0].db

predict = {'predict': hyps, 'db': db}
return predict

modelscope/outputs.py (+6, -0)

@@ -389,6 +389,12 @@ TASK_OUTPUTS = {
# }
Tasks.task_oriented_conversation: [OutputKeys.OUTPUT],

# conversational text-to-sql result for single sample
# {
# "text": "SELECT shop.Name FROM shop."
# }
Tasks.conversational_text_to_sql: [OutputKeys.TEXT],

# ============ audio tasks ===================
# asr result for single sample
# { "text": "每一天都要快乐喔"}


modelscope/pipelines/base.py (+3, -0)

@@ -239,6 +239,7 @@ class Pipeline(ABC):
"""
from torch.utils.data.dataloader import default_collate
from modelscope.preprocessors import InputFeatures
from text2sql_lgesql.utils.batch import Batch
if isinstance(data, dict) or isinstance(data, Mapping):
return type(data)(
{k: self._collate_fn(v)
@@ -259,6 +260,8 @@ class Pipeline(ABC):
return data
elif isinstance(data, InputFeatures):
return data
elif isinstance(data, Batch):
return data
else:
import mmcv
if isinstance(data, mmcv.parallel.data_container.DataContainer):


modelscope/pipelines/builder.py (+5, -0)

@@ -50,6 +50,11 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/nlp_structbert_zero-shot-classification_chinese-base'),
Tasks.task_oriented_conversation: (Pipelines.dialog_modeling,
'damo/nlp_space_dialog-modeling'),
Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking,
'damo/nlp_space_dialog-state-tracking'),
Tasks.conversational_text_to_sql:
(Pipelines.conversational_text_to_sql,
'damo/nlp_star_conversational-text-to-sql'),
Tasks.text_error_correction:
(Pipelines.text_error_correction,
'damo/nlp_bart_text-error-correction_chinese'),
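
With this default-model entry in place, the task name alone is enough to build the pipeline (mirroring test_run_with_default_model in the new test file); a small sketch:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # resolves to 'damo/nlp_star_conversational-text-to-sql' via DEFAULT_MODEL_FOR_PIPELINE
    text2sql = pipeline(task=Tasks.conversational_text_to_sql)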


modelscope/pipelines/nlp/__init__.py (+3, -0)

@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .conversational_text_to_sql_pipeline import ConversationalTextToSqlPipeline
from .dialog_intent_prediction_pipeline import DialogIntentPredictionPipeline
from .dialog_modeling_pipeline import DialogModelingPipeline
from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline
@@ -22,6 +23,8 @@ if TYPE_CHECKING:

else:
_import_structure = {
'conversational_text_to_sql_pipeline':
['ConversationalTextToSqlPipeline'],
'dialog_intent_prediction_pipeline':
['DialogIntentPredictionPipeline'],
'dialog_modeling_pipeline': ['DialogModelingPipeline'],


modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py (+66, -0)

@@ -0,0 +1,66 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Union

import torch
from text2sql_lgesql.utils.example import Example

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.models.nlp import StarForTextToSql
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import ConversationalTextToSqlPreprocessor
from modelscope.preprocessors.star.fields.common_utils import SubPreprocessor
from modelscope.preprocessors.star.fields.process_dataset import process_tables
from modelscope.utils.constant import Tasks

__all__ = ['ConversationalTextToSqlPipeline']


@PIPELINES.register_module(
Tasks.conversational_text_to_sql,
module_name=Pipelines.conversational_text_to_sql)
class ConversationalTextToSqlPipeline(Pipeline):

def __init__(self,
model: Union[StarForTextToSql, str],
preprocessor: ConversationalTextToSqlPreprocessor = None,
**kwargs):
"""use `model` and `preprocessor` to create a conversational text-to-sql prediction pipeline

Args:
model (StarForTextToSql or str): a model instance, or a model id / local model directory
preprocessor (ConversationalTextToSqlPreprocessor):
a preprocessor instance
"""
model = model if isinstance(
model, StarForTextToSql) else Model.from_pretrained(model)
if preprocessor is None:
preprocessor = ConversationalTextToSqlPreprocessor(model.model_dir)

preprocessor.device = 'cuda' if \
('device' not in kwargs or kwargs['device'] == 'gpu') \
and torch.cuda.is_available() else 'cpu'
use_gpu = preprocessor.device == 'cuda'
preprocessor.processor = \
SubPreprocessor(model_dir=model.model_dir,
db_content=True,
use_gpu=use_gpu)
preprocessor.output_tables = \
process_tables(preprocessor.processor,
preprocessor.tables)
super().__init__(model=model, preprocessor=preprocessor, **kwargs)

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""process the prediction results

Args:
inputs (Dict[str, Any]): the model output, with the predicted hypotheses and the database schema

Returns:
Dict[str, str]: the prediction results
"""
sql = Example.evaluator.obtain_sql(inputs['predict'][0], inputs['db'])
result = {OutputKeys.TEXT: sql}
return result

modelscope/preprocessors/__init__.py (+2, -0)

@@ -27,6 +27,7 @@ if TYPE_CHECKING:
DialogModelingPreprocessor,
DialogStateTrackingPreprocessor)
from .video import ReadVideoData
from .star import ConversationalTextToSqlPreprocessor

else:
_import_structure = {
@@ -55,6 +56,7 @@ else:
'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',
'DialogStateTrackingPreprocessor', 'InputFeatures'
],
'star': ['ConversationalTextToSqlPreprocessor'],
}

import sys


modelscope/preprocessors/star/__init__.py (+29, -0)

@@ -0,0 +1,29 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .conversational_text_to_sql_preprocessor import \
ConversationalTextToSqlPreprocessor
from .fields import SubPreprocessor, get_label, preprocess_dataset, process_dataset

else:
_import_structure = {
'conversational_text_to_sql_preprocessor':
['ConversationalTextToSqlPreprocessor'],
'fields': [
'get_label', 'SubPreprocessor', 'preprocess_dataset',
'process_dataset'
]
}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

modelscope/preprocessors/star/conversational_text_to_sql_preprocessor.py (+111, -0)

@@ -0,0 +1,111 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict

import json
import torch
from text2sql_lgesql.preprocess.graph_utils import GraphProcessor
from text2sql_lgesql.preprocess.process_graphs import process_dataset_graph
from text2sql_lgesql.utils.batch import Batch
from text2sql_lgesql.utils.example import Example

from modelscope.metainfo import Preprocessors
from modelscope.preprocessors.base import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.preprocessors.star.fields.preprocess_dataset import \
preprocess_dataset
from modelscope.preprocessors.star.fields.process_dataset import (
process_dataset, process_tables)
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModelFile
from modelscope.utils.type_assert import type_assert

__all__ = ['ConversationalTextToSqlPreprocessor']


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.conversational_text_to_sql)
class ConversationalTextToSqlPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
model_dir (str): model path
"""
super().__init__(*args, **kwargs)

self.model_dir: str = model_dir

self.config = Config.from_file(
os.path.join(self.model_dir, ModelFile.CONFIGURATION))
self.device = 'cuda' if \
('device' not in kwargs or kwargs['device'] == 'gpu') \
and torch.cuda.is_available() else 'cpu'
self.processor = None
self.table_path = os.path.join(self.model_dir, 'tables.json')
self.tables = json.load(open(self.table_path, 'r'))
self.output_tables = None
self.path_cache = []
self.graph_processor = GraphProcessor()

Example.configuration(
plm=self.config['model']['plm'],
tables=self.output_tables,
table_path=os.path.join(model_dir, 'tables.json'),
model_dir=self.model_dir,
db_dir=os.path.join(model_dir, 'db'))

@type_assert(object, dict)
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""process the raw input data

Args:
data (dict):
utterance: the current user utterance
history: previous utterances of the dialogue
last_sql: the SQL predicted for the previous utterance ('' for the first turn)
database_id: the db_id of the target schema in tables.json
local_db_path: optional path to a local tables.json + db/ folder (None to use the model's)
Example:
utterance: 'Which of these are hiring?'
last_sql: ''

Returns:
Dict[str, Any]: the preprocessed data
"""
# use local database
if data['local_db_path'] is not None and data[
'local_db_path'] not in self.path_cache:
self.path_cache.append(data['local_db_path'])
path = os.path.join(data['local_db_path'], 'tables.json')
self.tables = json.load(open(path, 'r'))
self.processor.db_dir = os.path.join(data['local_db_path'], 'db')
self.output_tables = process_tables(self.processor, self.tables)
Example.configuration(
plm=self.config['model']['plm'],
tables=self.output_tables,
table_path=path,
model_dir=self.model_dir,
db_dir=self.processor.db_dir)

theresult, sql_label = \
preprocess_dataset(
self.processor,
data,
self.output_tables,
data['database_id'],
self.tables
)
output_dataset = process_dataset(self.model_dir, self.processor,
theresult, self.output_tables)
output_dataset = \
process_dataset_graph(
self.graph_processor,
output_dataset,
self.output_tables,
method='lgesql'
)
dev_ex = Example(output_dataset[0],
self.output_tables[data['database_id']], sql_label)
current_batch = Batch.from_example_list([dev_ex],
self.device,
train=False)
return {'batch': current_batch, 'db': data['database_id']}
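
Note the local-database branch at the top of __call__: when the input carries a `local_db_path` that has not been seen before, tables.json and the db/ directory are loaded from that path instead of the model directory and Example is reconfigured against them. A hedged sketch of how a caller exercises this through the pipeline, reusing the `text2sql` pipeline from the sketch near the top of this change (path and database_id are illustrative; the folder must mirror the model's own tables.json + db/ layout):

    result = text2sql({
        'utterance': 'Which of these are hiring?',
        'history': [],
        'last_sql': '',
        'database_id': 'my_database',                  # a db_id present in the local tables.json
        'local_db_path': '/path/to/local_spider_dir'   # contains tables.json and a db/ folder
    })
    print(result['text'])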

modelscope/preprocessors/star/fields/__init__.py (+6, -0)

@@ -0,0 +1,6 @@
from modelscope.preprocessors.star.fields.common_utils import SubPreprocessor
from modelscope.preprocessors.star.fields.parse import get_label
from modelscope.preprocessors.star.fields.preprocess_dataset import \
preprocess_dataset
from modelscope.preprocessors.star.fields.process_dataset import \
process_dataset

modelscope/preprocessors/star/fields/common_utils.py (+471, -0)

@@ -0,0 +1,471 @@
# Copyright (c) rhythmcao modified from https://github.com/rhythmcao/text2sql-lgesql.

import os
import sqlite3
from itertools import combinations, product

import nltk
import numpy as np
from text2sql_lgesql.utils.constants import MAX_RELATIVE_DIST

from modelscope.utils.logger import get_logger

mwtokenizer = nltk.MWETokenizer(separator='')
mwtokenizer.add_mwe(('[', 'CLS', ']'))
logger = get_logger()


def is_number(s):
try:
float(s)
return True
except ValueError:
return False


def quote_normalization(question):
""" Normalize all usage of quotation marks into a separate \" """
new_question, quotation_marks = [], [
"'", '"', '`', '‘', '’', '“', '”', '``', "''", '‘‘', '’’'
]
for idx, tok in enumerate(question):
if len(tok) > 2 and tok[0] in quotation_marks and tok[
-1] in quotation_marks:
new_question += ["\"", tok[1:-1], "\""]
elif len(tok) > 2 and tok[0] in quotation_marks:
new_question += ["\"", tok[1:]]
elif len(tok) > 2 and tok[-1] in quotation_marks:
new_question += [tok[:-1], "\""]
elif tok in quotation_marks:
new_question.append("\"")
elif len(tok) == 2 and tok[0] in quotation_marks:
# special case: the length of entity value is 1
if idx + 1 < len(question) and question[idx
+ 1] in quotation_marks:
new_question += ["\"", tok[1]]
else:
new_question.append(tok)
else:
new_question.append(tok)
return new_question


class SubPreprocessor():

def __init__(self, model_dir, use_gpu=False, db_content=True):
super(SubPreprocessor, self).__init__()
self.model_dir = model_dir
self.db_dir = os.path.join(model_dir, 'db')
self.db_content = db_content

from nltk import data
from nltk.corpus import stopwords
data.path.append(os.path.join(self.model_dir, 'nltk_data'))
self.stopwords = stopwords.words('english')

import stanza
from stanza.resources import common
from stanza.pipeline import core
self.nlp = stanza.Pipeline(
'en',
use_gpu=use_gpu,
dir=self.model_dir,
processors='tokenize,pos,lemma',
tokenize_pretokenized=True,
download_method=core.DownloadMethod.REUSE_RESOURCES)
self.nlp1 = stanza.Pipeline(
'en',
use_gpu=use_gpu,
dir=self.model_dir,
processors='tokenize,pos,lemma',
download_method=core.DownloadMethod.REUSE_RESOURCES)

def pipeline(self, entry: dict, db: dict, verbose: bool = False):
""" db should be preprocessed """
entry = self.preprocess_question(entry, db, verbose=verbose)
entry = self.schema_linking(entry, db, verbose=verbose)
entry = self.extract_subgraph(entry, db, verbose=verbose)
return entry

def preprocess_database(self, db: dict, verbose: bool = False):
table_toks, table_names = [], []
for tab in db['table_names']:
doc = self.nlp1(tab)
tab = [w.lemma.lower() for s in doc.sentences for w in s.words]
table_toks.append(tab)
table_names.append(' '.join(tab))
db['processed_table_toks'], db[
'processed_table_names'] = table_toks, table_names
column_toks, column_names = [], []
for _, c in db['column_names']:
doc = self.nlp1(c)
c = [w.lemma.lower() for s in doc.sentences for w in s.words]
column_toks.append(c)
column_names.append(' '.join(c))
db['processed_column_toks'], db[
'processed_column_names'] = column_toks, column_names
column2table = list(map(lambda x: x[0], db['column_names']))
table2columns = [[] for _ in range(len(table_names))]
for col_id, col in enumerate(db['column_names']):
if col_id == 0:
continue
table2columns[col[0]].append(col_id)
db['column2table'], db['table2columns'] = column2table, table2columns

t_num, c_num, dtype = len(db['table_names']), len(
db['column_names']), '<U100'

tab_mat = np.array([['table-table-generic'] * t_num
for _ in range(t_num)],
dtype=dtype)
table_fks = set(
map(lambda pair: (column2table[pair[0]], column2table[pair[1]]),
db['foreign_keys']))
for (tab1, tab2) in table_fks:
if (tab2, tab1) in table_fks:
tab_mat[tab1, tab2], tab_mat[
tab2, tab1] = 'table-table-fkb', 'table-table-fkb'
else:
tab_mat[tab1, tab2], tab_mat[
tab2, tab1] = 'table-table-fk', 'table-table-fkr'
tab_mat[list(range(t_num)),
list(range(t_num))] = 'table-table-identity'

col_mat = np.array([['column-column-generic'] * c_num
for _ in range(c_num)],
dtype=dtype)
for i in range(t_num):
col_ids = [idx for idx, t in enumerate(column2table) if t == i]
col1, col2 = list(zip(*list(product(col_ids, col_ids))))
col_mat[col1, col2] = 'column-column-sametable'
col_mat[list(range(c_num)),
list(range(c_num))] = 'column-column-identity'
if len(db['foreign_keys']) > 0:
col1, col2 = list(zip(*db['foreign_keys']))
col_mat[col1, col2], col_mat[
col2, col1] = 'column-column-fk', 'column-column-fkr'
col_mat[0, list(range(c_num))] = '*-column-generic'
col_mat[list(range(c_num)), 0] = 'column-*-generic'
col_mat[0, 0] = '*-*-identity'

# relations between tables and columns, t_num*c_num and c_num*t_num
tab_col_mat = np.array([['table-column-generic'] * c_num
for _ in range(t_num)],
dtype=dtype)
col_tab_mat = np.array([['column-table-generic'] * t_num
for _ in range(c_num)],
dtype=dtype)
cols, tabs = list(
zip(*list(map(lambda x: (x, column2table[x]), range(1, c_num)))))
col_tab_mat[cols, tabs], tab_col_mat[
tabs, cols] = 'column-table-has', 'table-column-has'
if len(db['primary_keys']) > 0:
cols, tabs = list(
zip(*list(
map(lambda x: (x, column2table[x]), db['primary_keys']))))
col_tab_mat[cols, tabs], tab_col_mat[
tabs, cols] = 'column-table-pk', 'table-column-pk'
col_tab_mat[0, list(range(t_num))] = '*-table-generic'
tab_col_mat[list(range(t_num)), 0] = 'table-*-generic'

relations = \
np.concatenate([
np.concatenate([tab_mat, tab_col_mat], axis=1),
np.concatenate([col_tab_mat, col_mat], axis=1)
], axis=0)
db['relations'] = relations.tolist()

if verbose:
print('Tables:', ', '.join(db['table_names']))
print('Lemmatized:', ', '.join(table_names))
print('Columns:',
', '.join(list(map(lambda x: x[1], db['column_names']))))
print('Lemmatized:', ', '.join(column_names), '\n')
return db

def preprocess_question(self,
entry: dict,
db: dict,
verbose: bool = False):
""" Tokenize, lemmatize, lowercase question"""
# stanza tokenize, lemmatize and POS tag
question = ' '.join(quote_normalization(entry['question_toks']))

from nltk import data
data.path.append(os.path.join(self.model_dir, 'nltk_data'))
question = nltk.word_tokenize(question)
question = mwtokenizer.tokenize(question)

doc = self.nlp([question])
raw_toks = [w.text.lower() for s in doc.sentences for w in s.words]
toks = [w.lemma.lower() for s in doc.sentences for w in s.words]
pos_tags = [w.xpos for s in doc.sentences for w in s.words]

entry['raw_question_toks'] = raw_toks
entry['processed_question_toks'] = toks
entry['pos_tags'] = pos_tags

q_num, dtype = len(toks), '<U100'
if q_num <= MAX_RELATIVE_DIST + 1:
dist_vec = [
'question-question-dist'
+ str(i) if i != 0 else 'question-question-identity'
for i in range(-MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1, 1)
]
starting = MAX_RELATIVE_DIST
else:
dist_vec = ['question-question-generic'] \
* (q_num - MAX_RELATIVE_DIST - 1) + \
[
'question-question-dist' + str(i)
if i != 0 else 'question-question-identity'
for i in range(- MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1,
1)]\
+ ['question-question-generic'] \
* (q_num - MAX_RELATIVE_DIST - 1)
starting = q_num - 1
list_data = \
[dist_vec[starting - i:starting - i + q_num] for i in range(q_num)]
q_mat = \
np.array(
list_data,
dtype=dtype
)
entry['relations'] = q_mat.tolist()

if verbose:
print('Question:', entry['question'])
print('Tokenized:', ' '.join(entry['raw_question_toks']))
print('Lemmatized:', ' '.join(entry['processed_question_toks']))
print('Pos tags:', ' '.join(entry['pos_tags']), '\n')
return entry

def extract_subgraph(self, entry: dict, db: dict, verbose: bool = False):
used_schema = {'table': set(), 'column': set()}
entry['used_tables'] = sorted(list(used_schema['table']))
entry['used_columns'] = sorted(list(used_schema['column']))

if verbose:
print('Used tables:', entry['used_tables'])
print('Used columns:', entry['used_columns'], '\n')
return entry

def extract_subgraph_from_sql(self, sql: dict, used_schema: dict):
select_items = sql['select'][1]
# select clause
for _, val_unit in select_items:
if val_unit[0] == 0:
col_unit = val_unit[1]
used_schema['column'].add(col_unit[1])
else:
col_unit1, col_unit2 = val_unit[1:]
used_schema['column'].add(col_unit1[1])
used_schema['column'].add(col_unit2[1])
# from clause conds
table_units = sql['from']['table_units']
for _, t in table_units:
if type(t) == dict:
used_schema = self.extract_subgraph_from_sql(t, used_schema)
else:
used_schema['table'].add(t)
# from, where and having conds
used_schema = self.extract_subgraph_from_conds(sql['from']['conds'],
used_schema)
used_schema = self.extract_subgraph_from_conds(sql['where'],
used_schema)
used_schema = self.extract_subgraph_from_conds(sql['having'],
used_schema)
# groupBy and orderBy clause
groupBy = sql['groupBy']
for col_unit in groupBy:
used_schema['column'].add(col_unit[1])
orderBy = sql['orderBy']
if len(orderBy) > 0:
orderBy = orderBy[1]
for val_unit in orderBy:
if val_unit[0] == 0:
col_unit = val_unit[1]
used_schema['column'].add(col_unit[1])
else:
col_unit1, col_unit2 = val_unit[1:]
used_schema['column'].add(col_unit1[1])
used_schema['column'].add(col_unit2[1])
# union, intersect and except clause
if sql['intersect']:
used_schema = self.extract_subgraph_from_sql(
sql['intersect'], used_schema)
if sql['union']:
used_schema = self.extract_subgraph_from_sql(
sql['union'], used_schema)
if sql['except']:
used_schema = self.extract_subgraph_from_sql(
sql['except'], used_schema)
return used_schema

def extract_subgraph_from_conds(self, conds: list, used_schema: dict):
if len(conds) == 0:
return used_schema
for cond in conds:
if cond in ['and', 'or']:
continue
val_unit, val1, val2 = cond[2:]
if val_unit[0] == 0:
col_unit = val_unit[1]
used_schema['column'].add(col_unit[1])
else:
col_unit1, col_unit2 = val_unit[1:]
used_schema['column'].add(col_unit1[1])
used_schema['column'].add(col_unit2[1])
if type(val1) == list:
used_schema['column'].add(val1[1])
elif type(val1) == dict:
used_schema = self.extract_subgraph_from_sql(val1, used_schema)
if type(val2) == list:
used_schema['column'].add(val2[1])
elif type(val2) == dict:
used_schema = self.extract_subgraph_from_sql(val2, used_schema)
return used_schema

def schema_linking(self, entry: dict, db: dict, verbose: bool = False):
raw_question_toks, question_toks = entry['raw_question_toks'], entry[
'processed_question_toks']
table_toks, column_toks = db['processed_table_toks'], db[
'processed_column_toks']
table_names, column_names = db['processed_table_names'], db[
'processed_column_names']
q_num, t_num, c_num, dtype = len(question_toks), len(table_toks), len(
column_toks), '<U100'

# relations between questions and tables, q_num*t_num and t_num*q_num
table_matched_pairs = {'partial': [], 'exact': []}
q_tab_mat = np.array([['question-table-nomatch'] * t_num
for _ in range(q_num)],
dtype=dtype)
tab_q_mat = np.array([['table-question-nomatch'] * q_num
for _ in range(t_num)],
dtype=dtype)
max_len = max([len(t) for t in table_toks])
index_pairs = list(
filter(lambda x: x[1] - x[0] <= max_len,
combinations(range(q_num + 1), 2)))
index_pairs = sorted(index_pairs, key=lambda x: x[1] - x[0])
for i, j in index_pairs:
phrase = ' '.join(question_toks[i:j])
if phrase in self.stopwords:
continue
for idx, name in enumerate(table_names):
if phrase == name:
q_tab_mat[range(i, j), idx] = 'question-table-exactmatch'
tab_q_mat[idx, range(i, j)] = 'table-question-exactmatch'
if verbose:
table_matched_pairs['exact'].append(
str((name, idx, phrase, i, j)))
elif (j - i == 1
and phrase in name.split()) or (j - i > 1
and phrase in name):
q_tab_mat[range(i, j), idx] = 'question-table-partialmatch'
tab_q_mat[idx, range(i, j)] = 'table-question-partialmatch'
if verbose:
table_matched_pairs['partial'].append(
str((name, idx, phrase, i, j)))

# relations between questions and columns
column_matched_pairs = {'partial': [], 'exact': [], 'value': []}
q_col_mat = np.array([['question-column-nomatch'] * c_num
for _ in range(q_num)],
dtype=dtype)
col_q_mat = np.array([['column-question-nomatch'] * q_num
for _ in range(c_num)],
dtype=dtype)
max_len = max([len(c) for c in column_toks])
index_pairs = list(
filter(lambda x: x[1] - x[0] <= max_len,
combinations(range(q_num + 1), 2)))
index_pairs = sorted(index_pairs, key=lambda x: x[1] - x[0])
for i, j in index_pairs:
phrase = ' '.join(question_toks[i:j])
if phrase in self.stopwords:
continue
for idx, name in enumerate(column_names):
if phrase == name:
q_col_mat[range(i, j), idx] = 'question-column-exactmatch'
col_q_mat[idx, range(i, j)] = 'column-question-exactmatch'
if verbose:
column_matched_pairs['exact'].append(
str((name, idx, phrase, i, j)))
elif (j - i == 1
and phrase in name.split()) or (j - i > 1
and phrase in name):
q_col_mat[range(i, j),
idx] = 'question-column-partialmatch'
col_q_mat[idx,
range(i, j)] = 'column-question-partialmatch'
if verbose:
column_matched_pairs['partial'].append(
str((name, idx, phrase, i, j)))
if self.db_content:
db_file = os.path.join(self.db_dir, db['db_id'],
db['db_id'] + '.sqlite')
if not os.path.exists(db_file):
raise ValueError('[ERROR]: database file %s not found ...' %
(db_file))
conn = sqlite3.connect(db_file)
conn.text_factory = lambda b: b.decode(errors='ignore')
conn.execute('pragma foreign_keys=ON')
for i, (tab_id,
col_name) in enumerate(db['column_names_original']):
if i == 0 or 'id' in column_toks[
i]: # ignore * and special token 'id'
continue
tab_name = db['table_names_original'][tab_id]
try:
cursor = conn.execute("SELECT DISTINCT \"%s\" FROM \"%s\";"
% (col_name, tab_name))
cell_values = cursor.fetchall()
cell_values = [str(each[0]) for each in cell_values]
cell_values = [[str(float(each))] if is_number(each) else
each.lower().split()
for each in cell_values]
except Exception as e:
print(e)
cell_values = []  # skip value matching for this column if the query failed
for j, word in enumerate(raw_question_toks):
word = str(float(word)) if is_number(word) else word
for c in cell_values:
if word in c and 'nomatch' in q_col_mat[
j, i] and word not in self.stopwords:
q_col_mat[j, i] = 'question-column-valuematch'
col_q_mat[i, j] = 'column-question-valuematch'
if verbose:
column_matched_pairs['value'].append(
str((column_names[i], i, word, j, j + 1)))
break
conn.close()

q_col_mat[:, 0] = 'question-*-generic'
col_q_mat[0] = '*-question-generic'
q_schema = np.concatenate([q_tab_mat, q_col_mat], axis=1)
schema_q = np.concatenate([tab_q_mat, col_q_mat], axis=0)
entry['schema_linking'] = (q_schema.tolist(), schema_q.tolist())

if verbose:
print('Question:', ' '.join(question_toks))
print('Table matched: (table name, column id, \
question span, start id, end id)')
print(
'Exact match:', ', '.join(table_matched_pairs['exact'])
if table_matched_pairs['exact'] else 'empty')
print(
'Partial match:', ', '.join(table_matched_pairs['partial'])
if table_matched_pairs['partial'] else 'empty')
print('Column matched: (column name, column id, \
question span, start id, end id)')
print(
'Exact match:', ', '.join(column_matched_pairs['exact'])
if column_matched_pairs['exact'] else 'empty')
print(
'Partial match:', ', '.join(column_matched_pairs['partial'])
if column_matched_pairs['partial'] else 'empty')
print(
'Value match:', ', '.join(column_matched_pairs['value'])
if column_matched_pairs['value'] else 'empty', '\n')
return entry
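
As a quick illustration of the quote_normalization helper near the top of this file (the input tokens are illustrative):

    quote_normalization(['show', 'names', 'like', "'Tom'"])
    # -> ['show', 'names', 'like', '"', 'Tom', '"']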

modelscope/preprocessors/star/fields/parse.py (+333, -0)

@@ -0,0 +1,333 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

CLAUSE_KEYWORDS = ('SELECT', 'FROM', 'WHERE', 'GROUP', 'ORDER', 'LIMIT',
'INTERSECT', 'UNION', 'EXCEPT')
JOIN_KEYWORDS = ('JOIN', 'ON', 'AS')

WHERE_OPS = ('NOT_IN', 'BETWEEN', '=', '>', '<', '>=', '<=', '!=', 'IN',
'LIKE', 'IS', 'EXISTS')
UNIT_OPS = ('NONE', '-', '+', '*', '/')
AGG_OPS = ('', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG')
TABLE_TYPE = {
'sql': 'sql',
'table_unit': 'table_unit',
}
COND_OPS = ('AND', 'OR')
SQL_OPS = ('INTERSECT', 'UNION', 'EXCEPT')
ORDER_OPS = ('DESC', 'ASC')


def get_select_labels(select, slot, cur_nest):
for item in select[1]:
if AGG_OPS[item[0]] != '':
if slot[item[1][1][1]] == '':
slot[item[1][1][1]] += (cur_nest + ' ' + AGG_OPS[item[0]])
else:
slot[item[1][1][1]] += (' ' + cur_nest + ' '
+ AGG_OPS[item[0]])
else:
if slot[item[1][1][1]] == '':
slot[item[1][1][1]] += (cur_nest)
else:
slot[item[1][1][1]] += (' ' + cur_nest)
return slot


def get_groupby_labels(groupby, slot, cur_nest):
for item in groupby:
if slot[item[1]] == '':
slot[item[1]] += (cur_nest)
else:
slot[item[1]] += (' ' + cur_nest)
return slot


def get_orderby_labels(orderby, limit, slot, cur_nest):
if limit is None:
thelimit = ''
else:
thelimit = ' LIMIT'
for item in orderby[1]:
if AGG_OPS[item[1][0]] != '':
agg = ' ' + AGG_OPS[item[1][0]] + ' '
else:
agg = ' '
if slot[item[1][1]] == '':
slot[item[1][1]] += (
cur_nest + agg + orderby[0].upper() + thelimit)
else:
slot[item[1][1]] += (' ' + cur_nest + agg + orderby[0].upper()
+ thelimit)

return slot


def get_intersect_labels(intersect, slot, cur_nest):
if isinstance(intersect, dict):
if cur_nest != '':
slot = get_labels(intersect, slot, cur_nest)
else:
slot = get_labels(intersect, slot, 'INTERSECT')
else:
return slot
return slot


def get_except_labels(texcept, slot, cur_nest):
if isinstance(texcept, dict):
if cur_nest != '':
slot = get_labels(texcept, slot, cur_nest)
else:
slot = get_labels(texcept, slot, 'EXCEPT')
else:
return slot
return slot


def get_union_labels(union, slot, cur_nest):
if isinstance(union, dict):
if cur_nest != '':
slot = get_labels(union, slot, cur_nest)
else:
slot = get_labels(union, slot, 'UNION')
else:
return slot
return slot


def get_from_labels(tfrom, slot, cur_nest):
if tfrom['table_units'][0][0] == 'sql':
slot = get_labels(tfrom['table_units'][0][1], slot, 'OP_SEL')
else:
return slot
return slot


def get_having_labels(having, slot, cur_nest):
if len(having) == 1:
item = having[0]
if item[0] is True:
neg = ' NOT'
else:
neg = ''
if isinstance(item[3], dict):
if AGG_OPS[item[2][1][0]] != '':
agg = ' ' + AGG_OPS[item[2][1][0]]
else:
agg = ''
if slot[item[2][1][1]] == '':
slot[item[2][1][1]] += (
cur_nest + agg + neg + ' ' + WHERE_OPS[item[1]])
else:
slot[item[2][1][1]] += (' ' + cur_nest + agg + neg + ' '
+ WHERE_OPS[item[1]])
slot = get_labels(item[3], slot, 'OP_SEL')
else:
if AGG_OPS[item[2][1][0]] != '':
agg = ' ' + AGG_OPS[item[2][1][0]] + ' '
else:
agg = ' '
if slot[item[2][1][1]] == '':
slot[item[2][1][1]] += (cur_nest + agg + WHERE_OPS[item[1]])
else:
slot[item[2][1][1]] += (' ' + cur_nest + agg
+ WHERE_OPS[item[1]])
else:
for index, item in enumerate(having):
if item[0] is True:
neg = ' NOT'
else:
neg = ''
if (index + 1 < len(having) and having[index + 1] == 'or') or (
index - 1 >= 0 and having[index - 1] == 'or'):
if AGG_OPS[item[2][1][0]] != '':
agg = ' ' + AGG_OPS[item[2][1][0]]
else:
agg = ''
if isinstance(item[3], dict):
if slot[item[2][1][1]] == '':
slot[item[2][1][1]] += (
cur_nest + agg + neg + ' ' + WHERE_OPS[item[1]])
else:
slot[item[2][1][1]] += (' ' + cur_nest + agg + neg
+ ' ' + WHERE_OPS[item[1]])
slot = get_labels(item[3], slot, 'OP_SEL')
else:
if AGG_OPS[item[2][1][0]] != '':
agg = ' ' + AGG_OPS[item[2][1][0]] + ' '
else:
agg = ' '
if slot[item[2][1][1]] == '':
slot[item[2][1][1]] += (
cur_nest + ' OR' + agg + WHERE_OPS[item[1]])
else:
slot[item[2][1][1]] += (' ' + cur_nest + ' OR' + agg
+ WHERE_OPS[item[1]])
elif item == 'and' or item == 'or':
continue
else:
if isinstance(item[3], dict):
if slot[item[2][1][1]] == '':
slot[item[2][1][1]] += (
cur_nest + neg + ' ' + WHERE_OPS[item[1]])
else:
slot[item[2][1][1]] += (' ' + cur_nest + neg + ' '
+ WHERE_OPS[item[1]])
slot = get_labels(item[3], slot, 'OP_SEL')
else:
if AGG_OPS[item[2][1][0]] != '':
agg = ' ' + AGG_OPS[item[2][1][0]] + ' '
else:
agg = ' '
if slot[item[2][1][1]] == '':
slot[item[2][1][1]] += (
cur_nest + agg + WHERE_OPS[item[1]])
else:
slot[item[2][1][1]] += (' ' + cur_nest + agg
+ WHERE_OPS[item[1]])
return slot


def get_where_labels(where, slot, cur_nest):
if len(where) == 1:
item = where[0]
if item[0] is True:
neg = ' NOT'
else:
neg = ''
if isinstance(item[3], dict):
if slot[item[2][1][1]] == '':
slot[item[2][1][1]] += (
cur_nest + neg + ' ' + WHERE_OPS[item[1]])
else:
slot[item[2][1][1]] += (' ' + cur_nest + neg + ' '
+ WHERE_OPS[item[1]])
slot = get_labels(item[3], slot, 'OP_SEL')
else:
if slot[item[2][1][1]] == '':
slot[item[2][1][1]] += (
cur_nest + neg + ' ' + WHERE_OPS[item[1]])
else:
slot[item[2][1][1]] += (' ' + cur_nest + neg + ' '
+ WHERE_OPS[item[1]])
else:
for index, item in enumerate(where):
if item[0] is True:
neg = ' NOT'
else:
neg = ''
if (index + 1 < len(where) and where[index + 1] == 'or') or (
index - 1 >= 0 and where[index - 1] == 'or'):
if isinstance(item[3], dict):
if slot[item[2][1][1]] == '':
slot[item[2][1][1]] += (
cur_nest + neg + ' ' + WHERE_OPS[item[1]])
else:
slot[item[2][1][1]] += (' ' + cur_nest + neg + ' '
+ WHERE_OPS[item[1]])
slot = get_labels(item[3], slot, 'OP_SEL')
else:
if slot[item[2][1][1]] == '':
slot[item[2][1][1]] += (
cur_nest + ' OR' + neg + ' ' + WHERE_OPS[item[1]])
else:
slot[item[2][1][1]] += (' ' + cur_nest + ' OR' + neg
+ ' ' + WHERE_OPS[item[1]])
elif item == 'and' or item == 'or':
continue
else:
if isinstance(item[3], dict):
if slot[item[2][1][1]] == '':
slot[item[2][1][1]] += (
cur_nest + neg + ' ' + WHERE_OPS[item[1]])
else:
slot[item[2][1][1]] += (' ' + cur_nest + neg + ' '
+ WHERE_OPS[item[1]])
slot = get_labels(item[3], slot, 'OP_SEL')
else:
if slot[item[2][1][1]] == '':
slot[item[2][1][1]] += (
cur_nest + neg + ' ' + WHERE_OPS[item[1]])
else:
slot[item[2][1][1]] += (' ' + cur_nest + neg + ' '
+ WHERE_OPS[item[1]])
return slot


def get_labels(sql_struct, slot, cur_nest):

if len(sql_struct['select']) > 0:
if cur_nest != '':
slot = get_select_labels(sql_struct['select'], slot,
cur_nest + ' SELECT')
else:
slot = get_select_labels(sql_struct['select'], slot, 'SELECT')

if sql_struct['from']:
if cur_nest != '':
slot = get_from_labels(sql_struct['from'], slot, 'FROM')
else:
slot = get_from_labels(sql_struct['from'], slot, 'FROM')

if len(sql_struct['where']) > 0:
if cur_nest != '':
slot = get_where_labels(sql_struct['where'], slot,
cur_nest + ' WHERE')
else:
slot = get_where_labels(sql_struct['where'], slot, 'WHERE')

if len(sql_struct['groupBy']) > 0:
if cur_nest != '':
slot = get_groupby_labels(sql_struct['groupBy'], slot,
cur_nest + ' GROUP_BY')
else:
slot = get_groupby_labels(sql_struct['groupBy'], slot, 'GROUP_BY')

if len(sql_struct['having']) > 0:
if cur_nest != '':
slot = get_having_labels(sql_struct['having'], slot,
cur_nest + ' HAVING')
else:
slot = get_having_labels(sql_struct['having'], slot, 'HAVING')

if len(sql_struct['orderBy']) > 0:
if cur_nest != '':
slot = get_orderby_labels(sql_struct['orderBy'],
sql_struct['limit'], slot,
cur_nest + ' ORDER_BY')
else:
slot = get_orderby_labels(sql_struct['orderBy'],
sql_struct['limit'], slot, 'ORDER_BY')

if sql_struct['intersect']:
if cur_nest != '':
slot = get_intersect_labels(sql_struct['intersect'], slot,
cur_nest + ' INTERSECT')
else:
slot = get_intersect_labels(sql_struct['intersect'], slot,
'INTERSECT')

if sql_struct['except']:
if cur_nest != '':
slot = get_except_labels(sql_struct['except'], slot,
cur_nest + ' EXCEPT')
else:
slot = get_except_labels(sql_struct['except'], slot, 'EXCEPT')

if sql_struct['union']:
if cur_nest != '':
slot = get_union_labels(sql_struct['union'], slot,
cur_nest + ' UNION')
else:
slot = get_union_labels(sql_struct['union'], slot, 'UNION')
return slot


def get_label(sql, column_len):
thelabel = []
slot = {}
for idx in range(column_len):
slot[idx] = ''
for value in get_labels(sql, slot, '').values():
thelabel.append(value)
return thelabel

modelscope/preprocessors/star/fields/preprocess_dataset.py (+37, -0)

@@ -0,0 +1,37 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from text2sql_lgesql.preprocess.parse_raw_json import Schema, get_schemas
from text2sql_lgesql.process_sql import get_sql

from modelscope.preprocessors.star.fields.parse import get_label


def preprocess_dataset(processor, dataset, output_tables, database_id, tables):

schemas, db_names, thetables = get_schemas(tables)
intables = output_tables[database_id]
schema = schemas[database_id]
table = thetables[database_id]
sql_label = []
if len(dataset['history']) == 0 or dataset['last_sql'] == '':
sql_label = [''] * len(intables['column_names'])
else:
schema = Schema(schema, table)
try:
sql_label = get_sql(schema, dataset['last_sql'])
sql_label = get_label(sql_label, len(table['column_names_original']))
except Exception:
sql_label = [''] * len(intables['column_names'])
theone = {'db_id': database_id}
theone['query'] = ''
theone['query_toks_no_value'] = []
theone['sql'] = {}
if len(dataset['history']) != 0:
theone['question'] = dataset['utterance'] + ' [CLS] ' + ' [CLS] '.join(
dataset['history'][::-1][:4])
theone['question_toks'] = theone['question'].split()
else:
theone['question'] = dataset['utterance']
theone['question_toks'] = dataset['utterance'].split()

return [theone], sql_label
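
The multi-turn question that feeds the encoder is simply the current utterance followed by up to four previous turns in reverse chronological order, joined by ' [CLS] '. For example (turns are illustrative):

    utterance = 'Which shop is hiring the highest number of employees?'
    history = ["I'd like to see Shop names.", 'Which of these are hiring?']
    question = utterance + ' [CLS] ' + ' [CLS] '.join(history[::-1][:4])
    # "Which shop is hiring the highest number of employees? [CLS] Which of these are hiring? [CLS] I'd like to see Shop names."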

modelscope/preprocessors/star/fields/process_dataset.py (+64, -0)

@@ -0,0 +1,64 @@
# Copyright (c) rhythmcao modified from https://github.com/rhythmcao/text2sql-lgesql.

import argparse
import os
import pickle
import sys
import time

import json
from text2sql_lgesql.asdl.asdl import ASDLGrammar
from text2sql_lgesql.asdl.transition_system import TransitionSystem

from modelscope.preprocessors.star.fields.common_utils import SubPreprocessor

sys.path.append(os.path.dirname(os.path.dirname(__file__)))


def process_example(processor, entry, db, trans, verbose=False):
# preprocess raw tokens, schema linking and subgraph extraction
entry = processor.pipeline(entry, db, verbose=verbose)
# generate target output actions
entry['ast'] = []
entry['actions'] = []
return entry


def process_tables(processor, tables_list, output_path=None, verbose=False):
tables = {}
for each in tables_list:
if verbose:
print('*************** Processing database %s **************' %
(each['db_id']))
tables[each['db_id']] = processor.preprocess_database(
each, verbose=verbose)
print('In total, processed %d databases.' % len(tables))
if output_path is not None:
pickle.dump(tables, open(output_path, 'wb'))
return tables


def process_dataset(model_dir,
processor,
dataset,
tables,
output_path=None,
skip_large=False,
verbose=False):
grammar = ASDLGrammar.from_filepath(
os.path.join(model_dir, 'sql_asdl_v2.txt'))
trans = TransitionSystem.get_class_by_lang('sql')(grammar)
processed_dataset = []
for idx, entry in enumerate(dataset):
if skip_large and len(tables[entry['db_id']]['column_names']) > 100:
continue
if verbose:
print('*************** Processing %d-th sample **************' %
(idx))
entry = process_example(
processor, entry, tables[entry['db_id']], trans, verbose=verbose)
processed_dataset.append(entry)
if output_path is not None:
# serialize preprocessed dataset
pickle.dump(processed_dataset, open(output_path, 'wb'))
return processed_dataset

modelscope/utils/constant.py (+1, -0)

@@ -89,6 +89,7 @@ class NLPTasks(object):
zero_shot_classification = 'zero-shot-classification'
backbone = 'backbone'
text_error_correction = 'text-error-correction'
conversational_text_to_sql = 'conversational-text-to-sql'


class AudioTasks(object):


requirements/nlp.txt (+1, -0)

@@ -6,5 +6,6 @@ pai-easynlp
rouge_score<=0.0.4
seqeval
spacy>=2.3.5
text2sql_lgesql
tokenizers
transformers>=4.12.0

tests/pipelines/test_conversational_text_to_sql.py (+97, -0)

@@ -0,0 +1,97 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from typing import List

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import StarForTextToSql
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import ConversationalTextToSqlPipeline
from modelscope.preprocessors import ConversationalTextToSqlPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class ConversationalTextToSql(unittest.TestCase):
model_id = 'damo/nlp_star_conversational-text-to-sql'
test_case = {
'database_id':
'employee_hire_evaluation',
'local_db_path':
None,
'utterance': [
"I'd like to see Shop names.", 'Which of these are hiring?',
'Which shop is hiring the highest number of employees? | do you want the name of the shop ? | Yes'
]
}

def tracking_and_print_results(
self, pipelines: List[ConversationalTextToSqlPipeline]):
for my_pipeline in pipelines:
last_sql, history = '', []
for item in self.test_case['utterance']:
case = {
'utterance': item,
'history': history,
'last_sql': last_sql,
'database_id': self.test_case['database_id'],
'local_db_path': self.test_case['local_db_path']
}
results = my_pipeline(case)
print({'question': item})
print(results)
last_sql = results['text']
history.append(item)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self):
cache_path = snapshot_download(self.model_id)
preprocessor = ConversationalTextToSqlPreprocessor(
model_dir=cache_path,
database_id=self.test_case['database_id'],
db_content=True)
model = StarForTextToSql(
model_dir=cache_path, config=preprocessor.config)

pipelines = [
ConversationalTextToSqlPipeline(
model=model, preprocessor=preprocessor),
pipeline(
task=Tasks.conversational_text_to_sql,
model=model,
preprocessor=preprocessor)
]
self.tracking_and_print_results(pipelines)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
preprocessor = ConversationalTextToSqlPreprocessor(
model_dir=model.model_dir)

pipelines = [
ConversationalTextToSqlPipeline(
model=model, preprocessor=preprocessor),
pipeline(
task=Tasks.conversational_text_to_sql,
model=model,
preprocessor=preprocessor)
]
self.tracking_and_print_results(pipelines)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
pipelines = [
pipeline(
task=Tasks.conversational_text_to_sql, model=self.model_id)
]
self.tracking_and_print_results(pipelines)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipelines = [pipeline(task=Tasks.conversational_text_to_sql)]
self.tracking_and_print_results(pipelines)


if __name__ == '__main__':
unittest.main()
