From 2dc32865247cbe03bbaf484aadfeb30a70fb5348 Mon Sep 17 00:00:00 2001 From: "piaoyu.lxy" Date: Thu, 11 Aug 2022 11:19:11 +0800 Subject: [PATCH] [to #42322933] add conversational_text_to_sql pipeline Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9580066 --- modelscope/metainfo.py | 3 + modelscope/models/nlp/__init__.py | 2 + modelscope/models/nlp/star_text_to_sql.py | 68 +++ modelscope/outputs.py | 6 + modelscope/pipelines/base.py | 3 + modelscope/pipelines/builder.py | 5 + modelscope/pipelines/nlp/__init__.py | 3 + .../conversational_text_to_sql_pipeline.py | 66 +++ modelscope/preprocessors/__init__.py | 2 + modelscope/preprocessors/star/__init__.py | 29 ++ ...conversational_text_to_sql_preprocessor.py | 111 +++++ .../preprocessors/star/fields/__init__.py | 6 + .../preprocessors/star/fields/common_utils.py | 471 ++++++++++++++++++ modelscope/preprocessors/star/fields/parse.py | 333 +++++++++++++ .../star/fields/preprocess_dataset.py | 37 ++ .../star/fields/process_dataset.py | 64 +++ modelscope/utils/constant.py | 1 + requirements/nlp.txt | 1 + .../test_conversational_text_to_sql.py | 97 ++++ 19 files changed, 1308 insertions(+) create mode 100644 modelscope/models/nlp/star_text_to_sql.py create mode 100644 modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py create mode 100644 modelscope/preprocessors/star/__init__.py create mode 100644 modelscope/preprocessors/star/conversational_text_to_sql_preprocessor.py create mode 100644 modelscope/preprocessors/star/fields/__init__.py create mode 100644 modelscope/preprocessors/star/fields/common_utils.py create mode 100644 modelscope/preprocessors/star/fields/parse.py create mode 100644 modelscope/preprocessors/star/fields/preprocess_dataset.py create mode 100644 modelscope/preprocessors/star/fields/process_dataset.py create mode 100644 tests/pipelines/test_conversational_text_to_sql.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index cbab0e0b..220b3c32 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -29,6 +29,7 @@ class Models(object): space_dst = 'space-dst' space_intent = 'space-intent' space_modeling = 'space-modeling' + star = 'star' tcrf = 'transformer-crf' bart = 'bart' gpt3 = 'gpt3' @@ -123,6 +124,7 @@ class Pipelines(object): dialog_state_tracking = 'dialog-state-tracking' zero_shot_classification = 'zero-shot-classification' text_error_correction = 'text-error-correction' + conversational_text_to_sql = 'conversational-text-to-sql' # audio tasks sambert_hifigan_tts = 'sambert-hifigan-tts' @@ -201,6 +203,7 @@ class Preprocessors(object): text_error_correction = 'text-error-correction' word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor' fill_mask = 'fill-mask' + conversational_text_to_sql = 'conversational-text-to-sql' # audio preprocessor linear_aec_fbank = 'linear-aec-fbank' diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py index 24e65ef1..3fd76f98 100644 --- a/modelscope/models/nlp/__init__.py +++ b/modelscope/models/nlp/__init__.py @@ -17,12 +17,14 @@ if TYPE_CHECKING: from .space import SpaceForDialogIntent from .space import SpaceForDialogModeling from .space import SpaceForDialogStateTracking + from .star_text_to_sql import StarForTextToSql from .task_models.task_model import SingleBackboneTaskModelBase from .bart_for_text_error_correction import BartForTextErrorCorrection from .gpt3 import GPT3ForTextGeneration else: _import_structure = { + 'star_text_to_sql': ['StarForTextToSql'], 'backbones': 
['SbertModel'], 'heads': ['SequenceClassificationHead'], 'csanmt_for_translation': ['CsanmtForTranslation'], diff --git a/modelscope/models/nlp/star_text_to_sql.py b/modelscope/models/nlp/star_text_to_sql.py new file mode 100644 index 00000000..eef76e8a --- /dev/null +++ b/modelscope/models/nlp/star_text_to_sql.py @@ -0,0 +1,68 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Dict, Optional + +import torch +import torch.nn as nn +from text2sql_lgesql.asdl.asdl import ASDLGrammar +from text2sql_lgesql.asdl.transition_system import TransitionSystem +from text2sql_lgesql.model.model_constructor import Text2SQL +from text2sql_lgesql.utils.constants import GRAMMAR_FILEPATH + +from modelscope.metainfo import Models +from modelscope.models.base import Model, Tensor +from modelscope.models.builder import MODELS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks + +__all__ = ['StarForTextToSql'] + + +@MODELS.register_module( + Tasks.conversational_text_to_sql, module_name=Models.star) +class StarForTextToSql(Model): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the star model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + self.beam_size = 5 + self.config = kwargs.pop( + 'config', + Config.from_file( + os.path.join(self.model_dir, ModelFile.CONFIGURATION))) + self.config.model.model_dir = model_dir + self.grammar = ASDLGrammar.from_filepath( + os.path.join(model_dir, 'sql_asdl_v2.txt')) + self.trans = TransitionSystem.get_class_by_lang('sql')(self.grammar) + self.arg = self.config.model + self.device = 'cuda' if \ + ('device' not in kwargs or kwargs['device'] == 'gpu') \ + and torch.cuda.is_available() else 'cpu' + self.model = Text2SQL(self.arg, self.trans) + check_point = torch.load( + open( + os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'rb'), + map_location=self.device) + self.model.load_state_dict(check_point['model']) + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + Returns: + Dict[str, Tensor]: results + Example: + """ + self.model.eval() + hyps = self.model.parse(input['batch'], self.beam_size) # + db = input['batch'].examples[0].db + + predict = {'predict': hyps, 'db': db} + return predict diff --git a/modelscope/outputs.py b/modelscope/outputs.py index 3dc3cc44..6a45b3e3 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -389,6 +389,12 @@ TASK_OUTPUTS = { # } Tasks.task_oriented_conversation: [OutputKeys.OUTPUT], + # conversational text-to-sql result for single sample + # { + # "text": "SELECT shop.Name FROM shop." 
+ # } + Tasks.conversational_text_to_sql: [OutputKeys.TEXT], + # ============ audio tasks =================== # asr result for single sample # { "text": "每一天都要快乐喔"} diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index b1d82557..37d6f1e3 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -239,6 +239,7 @@ class Pipeline(ABC): """ from torch.utils.data.dataloader import default_collate from modelscope.preprocessors import InputFeatures + from text2sql_lgesql.utils.batch import Batch if isinstance(data, dict) or isinstance(data, Mapping): return type(data)( {k: self._collate_fn(v) @@ -259,6 +260,8 @@ class Pipeline(ABC): return data elif isinstance(data, InputFeatures): return data + elif isinstance(data, Batch): + return data else: import mmcv if isinstance(data, mmcv.parallel.data_container.DataContainer): diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 5c87bac5..12d8e4e9 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -50,6 +50,11 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/nlp_structbert_zero-shot-classification_chinese-base'), Tasks.task_oriented_conversation: (Pipelines.dialog_modeling, 'damo/nlp_space_dialog-modeling'), + Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking, + 'damo/nlp_space_dialog-state-tracking'), + Tasks.conversational_text_to_sql: + (Pipelines.conversational_text_to_sql, + 'damo/nlp_star_conversational-text-to-sql'), Tasks.text_error_correction: (Pipelines.text_error_correction, 'damo/nlp_bart_text-error-correction_chinese'), diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py index 1111f0d3..0cdb633c 100644 --- a/modelscope/pipelines/nlp/__init__.py +++ b/modelscope/pipelines/nlp/__init__.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: + from .conversational_text_to_sql_pipeline import ConversationalTextToSqlPipeline from .dialog_intent_prediction_pipeline import DialogIntentPredictionPipeline from .dialog_modeling_pipeline import DialogModelingPipeline from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline @@ -22,6 +23,8 @@ if TYPE_CHECKING: else: _import_structure = { + 'conversational_text_to_sql_pipeline': + ['ConversationalTextToSqlPipeline'], 'dialog_intent_prediction_pipeline': ['DialogIntentPredictionPipeline'], 'dialog_modeling_pipeline': ['DialogModelingPipeline'], diff --git a/modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py b/modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py new file mode 100644 index 00000000..875c47fd --- /dev/null +++ b/modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py @@ -0,0 +1,66 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
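Reviewer note: the `@PIPELINES.register_module` decorator in the new file below ties the pipeline class to a (task, pipeline-name) pair so that `pipeline(...)` can construct it later. A minimal sketch of that registry pattern, for orientation only (this is not ModelScope's actual builder code):

```python
# Illustrative registry sketch; ModelScope's real PIPELINES registry is richer.
PIPELINES = {}


def register_module(task, module_name):
    """Map (task, module_name) -> pipeline class at import time."""

    def decorator(cls):
        PIPELINES[(task, module_name)] = cls
        return cls

    return decorator


def build_pipeline(task, module_name, **kwargs):
    # look up the registered class and instantiate it
    return PIPELINES[(task, module_name)](**kwargs)
```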
+from typing import Any, Dict, Union
+
+import torch
+from text2sql_lgesql.utils.example import Example
+
+from modelscope.metainfo import Pipelines
+from modelscope.models import Model
+from modelscope.models.nlp import StarForTextToSql
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import ConversationalTextToSqlPreprocessor
+from modelscope.preprocessors.star.fields.common_utils import SubPreprocessor
+from modelscope.preprocessors.star.fields.process_dataset import process_tables
+from modelscope.utils.constant import Tasks
+
+__all__ = ['ConversationalTextToSqlPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.conversational_text_to_sql,
+    module_name=Pipelines.conversational_text_to_sql)
+class ConversationalTextToSqlPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[StarForTextToSql, str],
+                 preprocessor: ConversationalTextToSqlPreprocessor = None,
+                 **kwargs):
+        """use `model` and `preprocessor` to create a conversational text-to-sql prediction pipeline
+
+        Args:
+            model (StarForTextToSql or str): a model instance or a model id/path
+            preprocessor (ConversationalTextToSqlPreprocessor):
+                a preprocessor instance; created from the model dir if omitted
+        """
+        model = model if isinstance(
+            model, StarForTextToSql) else Model.from_pretrained(model)
+        if preprocessor is None:
+            preprocessor = ConversationalTextToSqlPreprocessor(model.model_dir)
+
+        preprocessor.device = 'cuda' if \
+            ('device' not in kwargs or kwargs['device'] == 'gpu') \
+            and torch.cuda.is_available() else 'cpu'
+        use_gpu = preprocessor.device == 'cuda'
+        preprocessor.processor = \
+            SubPreprocessor(model_dir=model.model_dir,
+                            db_content=True,
+                            use_gpu=use_gpu)
+        preprocessor.output_tables = \
+            process_tables(preprocessor.processor,
+                           preprocessor.tables)
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+        """process the prediction results
+
+        Args:
+            inputs (Dict[str, Any]): the forward results, with the beam
+                hypotheses under 'predict' and the database under 'db'
+
+        Returns:
+            Dict[str, str]: the prediction results
+        """
+        sql = Example.evaluator.obtain_sql(inputs['predict'][0], inputs['db'])
+        result = {OutputKeys.TEXT: sql}
+        return result
diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py
index c73a6c4f..9a2adb04 100644
--- a/modelscope/preprocessors/__init__.py
+++ b/modelscope/preprocessors/__init__.py
@@ -27,6 +27,7 @@ if TYPE_CHECKING:
         DialogModelingPreprocessor, DialogStateTrackingPreprocessor)
     from .video import ReadVideoData
+    from .star import ConversationalTextToSqlPreprocessor
 
 else:
     _import_structure = {
@@ -55,6 +56,7 @@ else:
             'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',
             'DialogStateTrackingPreprocessor', 'InputFeatures'
         ],
+        'star': ['ConversationalTextToSqlPreprocessor'],
     }
 
 import sys
diff --git a/modelscope/preprocessors/star/__init__.py b/modelscope/preprocessors/star/__init__.py
new file mode 100644
index 00000000..5a4bcea9
--- /dev/null
+++ b/modelscope/preprocessors/star/__init__.py
@@ -0,0 +1,29 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
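Reviewer note: the `__init__.py` below follows ModelScope's lazy-import convention, so heavy dependencies (stanza, text2sql_lgesql) are only imported when a name is first accessed. A self-contained sketch of the idea, assuming a simplified module layout (the real implementation is `LazyImportModule` in `modelscope.utils.import_utils`):

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Resolve exported names on first attribute access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # map exported symbol -> submodule that defines it
        self._locations = {
            symbol: submodule
            for submodule, symbols in import_structure.items()
            for symbol in symbols
        }

    def __getattr__(self, item):
        if item not in self._locations:
            raise AttributeError(item)
        module = importlib.import_module(
            f'{self.__name__}.{self._locations[item]}')
        value = getattr(module, item)
        setattr(self, item, value)  # cache so __getattr__ runs only once
        return value
```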
+from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .conversational_text_to_sql_preprocessor import \ + ConversationalTextToSqlPreprocessor + from .fields import MultiWOZBPETextField, IntentBPETextField + +else: + _import_structure = { + 'conversational_text_to_sql_preprocessor': + ['ConversationalTextToSqlPreprocessor'], + 'fields': [ + 'get_label', 'SubPreprocessor', 'preprocess_dataset', + 'process_dataset' + ] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/preprocessors/star/conversational_text_to_sql_preprocessor.py b/modelscope/preprocessors/star/conversational_text_to_sql_preprocessor.py new file mode 100644 index 00000000..2032dcf7 --- /dev/null +++ b/modelscope/preprocessors/star/conversational_text_to_sql_preprocessor.py @@ -0,0 +1,111 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import Any, Dict + +import json +import torch +from text2sql_lgesql.preprocess.graph_utils import GraphProcessor +from text2sql_lgesql.preprocess.process_graphs import process_dataset_graph +from text2sql_lgesql.utils.batch import Batch +from text2sql_lgesql.utils.example import Example + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.preprocessors.star.fields.preprocess_dataset import \ + preprocess_dataset +from modelscope.preprocessors.star.fields.process_dataset import ( + process_dataset, process_tables) +from modelscope.utils.config import Config +from modelscope.utils.constant import Fields, ModelFile +from modelscope.utils.type_assert import type_assert + +__all__ = ['ConversationalTextToSqlPreprocessor'] + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.conversational_text_to_sql) +class ConversationalTextToSqlPreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + """preprocess the data via the vocab.txt from the `model_dir` path + + Args: + model_dir (str): model path + """ + super().__init__(*args, **kwargs) + + self.model_dir: str = model_dir + + self.config = Config.from_file( + os.path.join(self.model_dir, ModelFile.CONFIGURATION)) + self.device = 'cuda' if \ + ('device' not in kwargs or kwargs['device'] == 'gpu') \ + and torch.cuda.is_available() else 'cpu' + self.processor = None + self.table_path = os.path.join(self.model_dir, 'tables.json') + self.tables = json.load(open(self.table_path, 'r')) + self.output_tables = None + self.path_cache = [] + self.graph_processor = GraphProcessor() + + Example.configuration( + plm=self.config['model']['plm'], + tables=self.output_tables, + table_path=os.path.join(model_dir, 'tables.json'), + model_dir=self.model_dir, + db_dir=os.path.join(model_dir, 'db')) + + @type_assert(object, dict) + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + """process the raw input data + + Args: + data (dict): + utterance: a sentence + last_sql: predicted sql of last utterance + Example: + utterance: 'Which of these are hiring?' 
+ last_sql: '' + + Returns: + Dict[str, Any]: the preprocessed data + """ + # use local database + if data['local_db_path'] is not None and data[ + 'local_db_path'] not in self.path_cache: + self.path_cache.append(data['local_db_path']) + path = os.path.join(data['local_db_path'], 'tables.json') + self.tables = json.load(open(path, 'r')) + self.processor.db_dir = os.path.join(data['local_db_path'], 'db') + self.output_tables = process_tables(self.processor, self.tables) + Example.configuration( + plm=self.config['model']['plm'], + tables=self.output_tables, + table_path=path, + model_dir=self.model_dir, + db_dir=self.processor.db_dir) + + theresult, sql_label = \ + preprocess_dataset( + self.processor, + data, + self.output_tables, + data['database_id'], + self.tables + ) + output_dataset = process_dataset(self.model_dir, self.processor, + theresult, self.output_tables) + output_dataset = \ + process_dataset_graph( + self.graph_processor, + output_dataset, + self.output_tables, + method='lgesql' + ) + dev_ex = Example(output_dataset[0], + self.output_tables[data['database_id']], sql_label) + current_batch = Batch.from_example_list([dev_ex], + self.device, + train=False) + return {'batch': current_batch, 'db': data['database_id']} diff --git a/modelscope/preprocessors/star/fields/__init__.py b/modelscope/preprocessors/star/fields/__init__.py new file mode 100644 index 00000000..1e95a998 --- /dev/null +++ b/modelscope/preprocessors/star/fields/__init__.py @@ -0,0 +1,6 @@ +from modelscope.preprocessors.star.fields.common_utils import SubPreprocessor +from modelscope.preprocessors.star.fields.parse import get_label +from modelscope.preprocessors.star.fields.preprocess_dataset import \ + preprocess_dataset +from modelscope.preprocessors.star.fields.process_dataset import \ + process_dataset diff --git a/modelscope/preprocessors/star/fields/common_utils.py b/modelscope/preprocessors/star/fields/common_utils.py new file mode 100644 index 00000000..2d33b7ab --- /dev/null +++ b/modelscope/preprocessors/star/fields/common_utils.py @@ -0,0 +1,471 @@ +# Copyright (c) rhythmcao modified from https://github.com/rhythmcao/text2sql-lgesql. 
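Reviewer note: `SubPreprocessor` below builds, for each database, a square matrix of string-valued relation names over the concatenated [tables; columns] node list; the LGESQL graph encoder consumes these as edge types. A toy illustration of the layout (shapes and fill values only, chosen for the example):

```python
import numpy as np

# For a database with t_num tables and c_num columns (column 0 is '*'),
# preprocess_database stores a (t_num + c_num) x (t_num + c_num) matrix
# of relation names under db['relations'].
t_num, c_num = 2, 5
dtype = '<U100'  # numpy unicode strings of up to 100 characters
tab_mat = np.full((t_num, t_num), 'table-table-generic', dtype=dtype)
col_mat = np.full((c_num, c_num), 'column-column-generic', dtype=dtype)
tab_col = np.full((t_num, c_num), 'table-column-generic', dtype=dtype)
col_tab = np.full((c_num, t_num), 'column-table-generic', dtype=dtype)
relations = np.concatenate([
    np.concatenate([tab_mat, tab_col], axis=1),
    np.concatenate([col_tab, col_mat], axis=1)
], axis=0)
assert relations.shape == (t_num + c_num, t_num + c_num)
```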
+import os
+import sqlite3
+from itertools import combinations, product
+
+import nltk
+import numpy as np
+from text2sql_lgesql.utils.constants import MAX_RELATIVE_DIST
+
+from modelscope.utils.logger import get_logger
+
+mwtokenizer = nltk.MWETokenizer(separator='')
+mwtokenizer.add_mwe(('[', 'CLS', ']'))
+logger = get_logger()
+
+
+def is_number(s):
+    try:
+        float(s)
+        return True
+    except ValueError:
+        return False
+
+
+def quote_normalization(question):
+    """ Normalize all usage of quotation marks into a separate \" """
+    new_question, quotation_marks = [], [
+        "'", '"', '`', '‘', '’', '“', '”', '``', "''", '‘‘', '’’'
+    ]
+    for idx, tok in enumerate(question):
+        if len(tok) > 2 and tok[0] in quotation_marks \
+                and tok[-1] in quotation_marks:
+            new_question += ["\"", tok[1:-1], "\""]
+        elif len(tok) > 2 and tok[0] in quotation_marks:
+            new_question += ["\"", tok[1:]]
+        elif len(tok) > 2 and tok[-1] in quotation_marks:
+            new_question += [tok[:-1], "\""]
+        elif tok in quotation_marks:
+            new_question.append("\"")
+        elif len(tok) == 2 and tok[0] in quotation_marks:
+            # special case: the length of the entity value is 1
+            if idx + 1 < len(question) \
+                    and question[idx + 1] in quotation_marks:
+                new_question += ["\"", tok[1]]
+            else:
+                new_question.append(tok)
+        else:
+            new_question.append(tok)
+    return new_question
+
+
+class SubPreprocessor():
+
+    def __init__(self, model_dir, use_gpu=False, db_content=True):
+        super(SubPreprocessor, self).__init__()
+        self.model_dir = model_dir
+        self.db_dir = os.path.join(model_dir, 'db')
+        self.db_content = db_content
+
+        from nltk import data
+        from nltk.corpus import stopwords
+        data.path.append(os.path.join(self.model_dir, 'nltk_data'))
+        self.stopwords = stopwords.words('english')
+
+        import stanza
+        from stanza.pipeline import core
+        # pretokenized pipeline for questions, plain pipeline for schema names
+        self.nlp = stanza.Pipeline(
+            'en',
+            use_gpu=use_gpu,
+            dir=self.model_dir,
+            processors='tokenize,pos,lemma',
+            tokenize_pretokenized=True,
+            download_method=core.DownloadMethod.REUSE_RESOURCES)
+        self.nlp1 = stanza.Pipeline(
+            'en',
+            use_gpu=use_gpu,
+            dir=self.model_dir,
+            processors='tokenize,pos,lemma',
+            download_method=core.DownloadMethod.REUSE_RESOURCES)
+
+    def pipeline(self, entry: dict, db: dict, verbose: bool = False):
+        """ db should be preprocessed """
+        entry = self.preprocess_question(entry, db, verbose=verbose)
+        entry = self.schema_linking(entry, db, verbose=verbose)
+        entry = self.extract_subgraph(entry, db, verbose=verbose)
+        return entry
+
+    def preprocess_database(self, db: dict, verbose: bool = False):
+        """Tokenize, lemmatize and lowercase table/column names, then build
+        the schema relation matrix consumed by the LGESQL graph encoder."""
+        table_toks, table_names = [], []
+        for tab in db['table_names']:
+            doc = self.nlp1(tab)
+            tab = [w.lemma.lower() for s in doc.sentences for w in s.words]
+            table_toks.append(tab)
+            table_names.append(' '.join(tab))
+        db['processed_table_toks'] = table_toks
+        db['processed_table_names'] = table_names
+        column_toks, column_names = [], []
+        for _, c in db['column_names']:
+            doc = self.nlp1(c)
+            c = [w.lemma.lower() for s in doc.sentences for w in s.words]
+            column_toks.append(c)
+            column_names.append(' '.join(c))
+        db['processed_column_toks'] = column_toks
+        db['processed_column_names'] = column_names
+        column2table = list(map(lambda x: x[0], db['column_names']))
+        table2columns = [[] for _ in range(len(table_names))]
+        for col_id, col in enumerate(db['column_names']):
+            if col_id == 0:
+                continue
+            table2columns[col[0]].append(col_id)
+        db['column2table'], db['table2columns'] = column2table, table2columns
+
+        t_num, c_num = len(db['table_names']), len(db['column_names'])
+        dtype = '<U100'
+
+        # relations between tables, t_num*t_num
+        tab_mat = np.array([['table-table-generic'] * t_num
+                            for _ in range(t_num)], dtype=dtype)
+        table_fks = set(
+            map(lambda pair: (column2table[pair[0]], column2table[pair[1]]),
+                db['foreign_keys']))
+        for (tab1, tab2) in table_fks:
+            if (tab2, tab1) in table_fks:
+                tab_mat[tab1, tab2] = tab_mat[tab2, tab1] = 'table-table-fkb'
+            else:
+                tab_mat[tab1, tab2], tab_mat[tab2, tab1] = \
+                    'table-table-fk', 'table-table-fkr'
+        tab_mat[list(range(t_num)),
+                list(range(t_num))] = 'table-table-identity'
+
+        # relations between columns, c_num*c_num
+        col_mat = np.array([['column-column-generic'] * c_num
+                            for _ in range(c_num)], dtype=dtype)
+        for i in range(t_num):
+            col_ids = [idx for idx, t in enumerate(column2table) if t == i]
+            col1, col2 = list(zip(*list(product(col_ids, col_ids))))
+            col_mat[col1, col2] = 'column-column-sametable'
+        col_mat[list(range(c_num)),
+                list(range(c_num))] = 'column-column-identity'
+        if len(db['foreign_keys']) > 0:
+            col1, col2 = list(zip(*db['foreign_keys']))
+            col_mat[col1, col2], col_mat[col2, col1] = \
+                'column-column-fk', 'column-column-fkr'
+        col_mat[0, list(range(c_num))] = '*-column-generic'
+        col_mat[list(range(c_num)), 0] = 'column-*-generic'
+        col_mat[0, 0] = '*-*-identity'
+
+        # relations between tables and columns, t_num*c_num and c_num*t_num
+        tab_col_mat = np.array([['table-column-generic'] * c_num
+                                for _ in range(t_num)], dtype=dtype)
+        col_tab_mat = np.array([['column-table-generic'] * t_num
+                                for _ in range(c_num)], dtype=dtype)
+        cols, tabs = list(
+            zip(*list(map(lambda x: (x, column2table[x]), range(1, c_num)))))
+        col_tab_mat[cols, tabs] = 'column-table-has'
+        tab_col_mat[tabs, cols] = 'table-column-has'
+        if len(db['primary_keys']) > 0:
+            cols, tabs = list(
+                zip(*list(
+                    map(lambda x: (x, column2table[x]), db['primary_keys']))))
+            col_tab_mat[cols, tabs] = 'column-table-pk'
+            tab_col_mat[tabs, cols] = 'table-column-pk'
+        col_tab_mat[0, list(range(t_num))] = '*-table-generic'
+        tab_col_mat[list(range(t_num)), 0] = 'table-*-generic'
+
+        relations = \
+            np.concatenate([
+                np.concatenate([tab_mat, tab_col_mat], axis=1),
+                np.concatenate([col_tab_mat, col_mat], axis=1)
+            ], axis=0)
+        db['relations'] = relations.tolist()
+
+        if verbose:
+            print('Tables:', ', '.join(db['table_names']))
+            print('Lemmatized:', ', '.join(table_names))
+            print('Columns:',
+                  ', '.join(list(map(lambda x: x[1], db['column_names']))))
+            print('Lemmatized:', ', '.join(column_names), '\n')
+        return db
+
+    def preprocess_question(self,
+                            entry: dict,
+                            db: dict,
+                            verbose: bool = False):
+        """ Tokenize, lemmatize, lowercase question"""
+        # stanza tokenize, lemmatize and POS tag
+        question = ' '.join(quote_normalization(entry['question_toks']))
+
+        from nltk import data
+        data.path.append(os.path.join(self.model_dir, 'nltk_data'))
+        question = nltk.word_tokenize(question)
+        question = mwtokenizer.tokenize(question)
+
+        doc = self.nlp([question])
+        raw_toks = [w.text.lower() for s in doc.sentences for w in s.words]
+        toks = [w.lemma.lower() for s in doc.sentences for w in s.words]
+        pos_tags = [w.xpos for s in doc.sentences for w in s.words]
+
+        entry['raw_question_toks'] = raw_toks
+        entry['processed_question_toks'] = toks
+        entry['pos_tags'] = pos_tags
+
+        # relations between question tokens, q_num*q_num
+        q_num, dtype = len(toks), '<U100'
+        if q_num <= MAX_RELATIVE_DIST + 1:
+            dist_vec = [
+                'question-question-dist' + str(i) if i != 0 else
+                'question-question-identity'
+                for i in range(-MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1)
+            ]
+            starting = MAX_RELATIVE_DIST
+        else:
+            dist_vec = ['question-question-generic'] * \
+                (q_num - MAX_RELATIVE_DIST - 1)
+            dist_vec += [
+                'question-question-dist' + str(i) if i != 0 else
+                'question-question-identity'
+                for i in range(-MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1)
+            ]
+            dist_vec += ['question-question-generic'] * \
+                (q_num - MAX_RELATIVE_DIST - 1)
+            starting = q_num - 1
+        relations = [
+            dist_vec[starting - i:starting - i + q_num] for i in range(q_num)
+        ]
+        entry['relations'] = relations
+
+        if verbose:
+            print('Tokenized:', ' '.join(raw_toks))
+            print('Lemmatized:', ' '.join(toks))
+            print('Pos tags:', ' '.join(pos_tags), '\n')
+        return entry
+
+    def extract_subgraph(self, entry: dict, db: dict, verbose: bool = False):
+        sql = entry['sql']
+        used_schema = {'table': set(), 'column': set()}
+        # at inference time `sql` may be an empty dict (see preprocess_dataset)
+        if sql:
+            used_schema = self.extract_subgraph_from_sql(sql, used_schema)
+        entry['used_tables'] = sorted(list(used_schema['table']))
+        entry['used_columns'] = sorted(list(used_schema['column']))
+        if verbose:
+            print('Used tables:', entry['used_tables'])
+            print('Used columns:', entry['used_columns'], '\n')
+        return entry
+
+    def extract_subgraph_from_sql(self, sql: dict, used_schema: dict):
+        # select clause
+        select_items = sql['select'][1]
+        for _, val_unit in select_items:
+            if val_unit[0] == 0:
+                col_unit = val_unit[1]
+                used_schema['column'].add(col_unit[1])
+            else:
+                col_unit1, col_unit2 = val_unit[1:]
+                used_schema['column'].add(col_unit1[1])
+                used_schema['column'].add(col_unit2[1])
+        # from clause
+        table_units = sql['from']['table_units']
+        for _, t in table_units:
+            if type(t) == dict:
+                used_schema = self.extract_subgraph_from_sql(t, used_schema)
+            else:
+                used_schema['table'].add(t)
+        # from, where and having conds
+        used_schema = self.extract_subgraph_from_conds(
+            sql['from']['conds'], used_schema)
+        used_schema = self.extract_subgraph_from_conds(sql['where'],
+                                                       used_schema)
+        used_schema = self.extract_subgraph_from_conds(sql['having'],
+                                                       used_schema)
+        # groupBy clause
+        groupBy = sql['groupBy']
+        for col_unit in groupBy:
+            used_schema['column'].add(col_unit[1])
+        # orderBy clause
+        orderBy = sql['orderBy']
+        if len(orderBy) > 0:
+            orderBy = orderBy[1]
+            for val_unit in orderBy:
+                if val_unit[0] == 0:
+                    col_unit = val_unit[1]
+                    used_schema['column'].add(col_unit[1])
+                else:
+                    col_unit1, col_unit2 = val_unit[1:]
+                    used_schema['column'].add(col_unit1[1])
+                    used_schema['column'].add(col_unit2[1])
+        # union, intersect and except clause
+        if sql['intersect']:
+            used_schema = self.extract_subgraph_from_sql(
+                sql['intersect'], used_schema)
+        if sql['union']:
+            used_schema = self.extract_subgraph_from_sql(
+                sql['union'], used_schema)
+        if sql['except']:
+            used_schema = self.extract_subgraph_from_sql(
+                sql['except'], used_schema)
+        return used_schema
+
+    def extract_subgraph_from_conds(self, conds: list, used_schema: dict):
+        if len(conds) == 0:
+            return used_schema
+        for cond in conds:
+            if cond in ['and', 'or']:
+                continue
+            val_unit, val1, val2 = cond[2:]
+            if val_unit[0] == 0:
+                col_unit = val_unit[1]
+                used_schema['column'].add(col_unit[1])
+            else:
+                col_unit1, col_unit2 = val_unit[1:]
+                used_schema['column'].add(col_unit1[1])
+                used_schema['column'].add(col_unit2[1])
+            if type(val1) == list:
+                used_schema['column'].add(val1[1])
+            elif type(val1) == dict:
+                used_schema = self.extract_subgraph_from_sql(val1, used_schema)
+            if type(val2) == list:
+                used_schema['column'].add(val2[1])
+            elif type(val2) == dict:
+                used_schema = self.extract_subgraph_from_sql(val2, used_schema)
+        return used_schema
+
+    def schema_linking(self, entry: dict, db: dict, verbose: bool = False):
+        raw_question_toks = entry['raw_question_toks']
+        question_toks = entry['processed_question_toks']
+        table_toks = db['processed_table_toks']
+        column_toks = db['processed_column_toks']
+        table_names = db['processed_table_names']
+        column_names = db['processed_column_names']
+        q_num, t_num, c_num = len(question_toks), len(table_toks), len(
+            column_toks)
+        dtype = '<U100'
+
+        # relations between questions and tables
+        table_matched_pairs = {'partial': [], 'exact': []}
+        q_tab_mat = np.array([['question-table-nomatch'] * t_num
+                              for _ in range(q_num)], dtype=dtype)
+        tab_q_mat = np.array([['table-question-nomatch'] * q_num
+                              for _ in range(t_num)], dtype=dtype)
+        max_len = max([len(t) for t in table_toks])
+        index_pairs = list(
+            filter(lambda x: x[1] - x[0] <= max_len,
+                   combinations(range(q_num + 1), 2)))
+        index_pairs = sorted(index_pairs, key=lambda x: x[1] - x[0])
+        for i, j in index_pairs:
+            phrase = ' '.join(question_toks[i:j])
+            if phrase in self.stopwords:
+                continue
+            for idx, name in enumerate(table_names):
+                if phrase == name:
+                    # exact match overwrites partial match due to the sort
+                    q_tab_mat[range(i, j), idx] = 'question-table-exactmatch'
+                    tab_q_mat[idx, range(i, j)] = 'table-question-exactmatch'
+                    if verbose:
+                        table_matched_pairs['exact'].append(
+                            str((name, idx, phrase, i, j)))
+                elif (j - i == 1
+                      and phrase in name.split()) or (j - i > 1
+                                                      and phrase in name):
+                    q_tab_mat[range(i, j), idx] = 'question-table-partialmatch'
+                    tab_q_mat[idx, range(i, j)] = 'table-question-partialmatch'
+                    if verbose:
+                        table_matched_pairs['partial'].append(
+                            str((name, idx, phrase, i, j)))
+
+        # relations between questions and columns
+        column_matched_pairs = {'partial': [], 'exact': [], 'value': []}
+        q_col_mat = np.array([['question-column-nomatch'] * c_num
+                              for _ in range(q_num)],
+                             dtype=dtype)
+        col_q_mat = np.array([['column-question-nomatch'] * q_num
+                              for _ in range(c_num)],
+                             dtype=dtype)
+        max_len = max([len(c) for c in column_toks])
+        index_pairs = list(
+            filter(lambda x: x[1] - x[0] <= max_len,
+                   combinations(range(q_num + 1), 2)))
+        index_pairs = sorted(index_pairs, key=lambda x: x[1] - x[0])
+        for i, j in index_pairs:
+            phrase = ' '.join(question_toks[i:j])
+            if phrase in self.stopwords:
+                continue
+            for idx, name in enumerate(column_names):
+                if phrase == name:
+                    q_col_mat[range(i, j), idx] = 'question-column-exactmatch'
+                    col_q_mat[idx, range(i, j)] = 'column-question-exactmatch'
+                    if verbose:
+                        column_matched_pairs['exact'].append(
+                            str((name, idx, phrase, i, j)))
+                elif (j - i == 1
+                      and phrase in name.split()) or (j - i > 1
+                                                      and phrase in name):
+                    q_col_mat[range(i, j),
+                              idx] = 'question-column-partialmatch'
+                    col_q_mat[idx,
+                              range(i, j)] = 'column-question-partialmatch'
+                    if verbose:
+                        column_matched_pairs['partial'].append(
+                            str((name, idx, phrase, i, j)))
+        if self.db_content:
+            db_file = os.path.join(self.db_dir, db['db_id'],
+                                   db['db_id'] + '.sqlite')
+            if not os.path.exists(db_file):
+                raise ValueError('[ERROR]: database file %s not found ...'
% + (db_file)) + conn = sqlite3.connect(db_file) + conn.text_factory = lambda b: b.decode(errors='ignore') + conn.execute('pragma foreign_keys=ON') + for i, (tab_id, + col_name) in enumerate(db['column_names_original']): + if i == 0 or 'id' in column_toks[ + i]: # ignore * and special token 'id' + continue + tab_name = db['table_names_original'][tab_id] + try: + cursor = conn.execute("SELECT DISTINCT \"%s\" FROM \"%s\";" + % (col_name, tab_name)) + cell_values = cursor.fetchall() + cell_values = [str(each[0]) for each in cell_values] + cell_values = [[str(float(each))] if is_number(each) else + each.lower().split() + for each in cell_values] + except Exception as e: + print(e) + for j, word in enumerate(raw_question_toks): + word = str(float(word)) if is_number(word) else word + for c in cell_values: + if word in c and 'nomatch' in q_col_mat[ + j, i] and word not in self.stopwords: + q_col_mat[j, i] = 'question-column-valuematch' + col_q_mat[i, j] = 'column-question-valuematch' + if verbose: + column_matched_pairs['value'].append( + str((column_names[i], i, word, j, j + 1))) + break + conn.close() + + q_col_mat[:, 0] = 'question-*-generic' + col_q_mat[0] = '*-question-generic' + q_schema = np.concatenate([q_tab_mat, q_col_mat], axis=1) + schema_q = np.concatenate([tab_q_mat, col_q_mat], axis=0) + entry['schema_linking'] = (q_schema.tolist(), schema_q.tolist()) + + if verbose: + print('Question:', ' '.join(question_toks)) + print('Table matched: (table name, column id, \ + question span, start id, end id)') + print( + 'Exact match:', ', '.join(table_matched_pairs['exact']) + if table_matched_pairs['exact'] else 'empty') + print( + 'Partial match:', ', '.join(table_matched_pairs['partial']) + if table_matched_pairs['partial'] else 'empty') + print('Column matched: (column name, column id, \ + question span, start id, end id)') + print( + 'Exact match:', ', '.join(column_matched_pairs['exact']) + if column_matched_pairs['exact'] else 'empty') + print( + 'Partial match:', ', '.join(column_matched_pairs['partial']) + if column_matched_pairs['partial'] else 'empty') + print( + 'Value match:', ', '.join(column_matched_pairs['value']) + if column_matched_pairs['value'] else 'empty', '\n') + return entry diff --git a/modelscope/preprocessors/star/fields/parse.py b/modelscope/preprocessors/star/fields/parse.py new file mode 100644 index 00000000..02ae31a0 --- /dev/null +++ b/modelscope/preprocessors/star/fields/parse.py @@ -0,0 +1,333 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
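Reviewer note: `parse.py` turns a Spider-style parsed SQL dict into one tag string per schema column; these become the `sql_label` features in `preprocess_dataset`. A hand-written example of the encoding and the expected labels (the column ids and the four-column schema here are hypothetical):

```python
from modelscope.preprocessors.star.fields.parse import get_label

# Spider encodes `SELECT Name FROM shop` roughly like this:
# col_unit = (agg_id, col_id, distinct); val_unit = (unit_op, col_unit, col_unit2)
sql = {
    'select': (False, [(0, (0, (0, 2, False), None))]),  # col id 2 = shop.Name
    'from': {'table_units': [('table_unit', 0)], 'conds': []},
    'where': [],
    'groupBy': [],
    'having': [],
    'orderBy': [],
    'limit': None,
    'intersect': None,
    'union': None,
    'except': None,
}
# one tag string per column id; only the selected column is tagged
print(get_label(sql, 4))  # -> ['', '', 'SELECT', '']
```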
+ +CLAUSE_KEYWORDS = ('SELECT', 'FROM', 'WHERE', 'GROUP', 'ORDER', 'LIMIT', + 'INTERSECT', 'UNION', 'EXCEPT') +JOIN_KEYWORDS = ('JOIN', 'ON', 'AS') + +WHERE_OPS = ('NOT_IN', 'BETWEEN', '=', '>', '<', '>=', '<=', '!=', 'IN', + 'LIKE', 'IS', 'EXISTS') +UNIT_OPS = ('NONE', '-', '+', '*', '/') +AGG_OPS = ('', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG') +TABLE_TYPE = { + 'sql': 'sql', + 'table_unit': 'table_unit', +} +COND_OPS = ('AND', 'OR') +SQL_OPS = ('INTERSECT', 'UNION', 'EXCEPT') +ORDER_OPS = ('DESC', 'ASC') + + +def get_select_labels(select, slot, cur_nest): + for item in select[1]: + if AGG_OPS[item[0]] != '': + if slot[item[1][1][1]] == '': + slot[item[1][1][1]] += (cur_nest + ' ' + AGG_OPS[item[0]]) + else: + slot[item[1][1][1]] += (' ' + cur_nest + ' ' + + AGG_OPS[item[0]]) + else: + if slot[item[1][1][1]] == '': + slot[item[1][1][1]] += (cur_nest) + else: + slot[item[1][1][1]] += (' ' + cur_nest) + return slot + + +def get_groupby_labels(groupby, slot, cur_nest): + for item in groupby: + if slot[item[1]] == '': + slot[item[1]] += (cur_nest) + else: + slot[item[1]] += (' ' + cur_nest) + return slot + + +def get_orderby_labels(orderby, limit, slot, cur_nest): + if limit is None: + thelimit = '' + else: + thelimit = ' LIMIT' + for item in orderby[1]: + if AGG_OPS[item[1][0]] != '': + agg = ' ' + AGG_OPS[item[1][0]] + ' ' + else: + agg = ' ' + if slot[item[1][1]] == '': + slot[item[1][1]] += ( + cur_nest + agg + orderby[0].upper() + thelimit) + else: + slot[item[1][1]] += (' ' + cur_nest + agg + orderby[0].upper() + + thelimit) + + return slot + + +def get_intersect_labels(intersect, slot, cur_nest): + if isinstance(intersect, dict): + if cur_nest != '': + slot = get_labels(intersect, slot, cur_nest) + else: + slot = get_labels(intersect, slot, 'INTERSECT') + else: + return slot + return slot + + +def get_except_labels(texcept, slot, cur_nest): + if isinstance(texcept, dict): + if cur_nest != '': + slot = get_labels(texcept, slot, cur_nest) + else: + slot = get_labels(texcept, slot, 'EXCEPT') + else: + return slot + return slot + + +def get_union_labels(union, slot, cur_nest): + if isinstance(union, dict): + if cur_nest != '': + slot = get_labels(union, slot, cur_nest) + else: + slot = get_labels(union, slot, 'UNION') + else: + return slot + return slot + + +def get_from_labels(tfrom, slot, cur_nest): + if tfrom['table_units'][0][0] == 'sql': + slot = get_labels(tfrom['table_units'][0][1], slot, 'OP_SEL') + else: + return slot + return slot + + +def get_having_labels(having, slot, cur_nest): + if len(having) == 1: + item = having[0] + if item[0] is True: + neg = ' NOT' + else: + neg = '' + if isinstance(item[3], dict): + if AGG_OPS[item[2][1][0]] != '': + agg = ' ' + AGG_OPS[item[2][1][0]] + else: + agg = '' + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + agg + neg + ' ' + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + agg + neg + ' ' + + WHERE_OPS[item[1]]) + slot = get_labels(item[3], slot, 'OP_SEL') + else: + if AGG_OPS[item[2][1][0]] != '': + agg = ' ' + AGG_OPS[item[2][1][0]] + ' ' + else: + agg = ' ' + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += (cur_nest + agg + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + agg + + WHERE_OPS[item[1]]) + else: + for index, item in enumerate(having): + if item[0] is True: + neg = ' NOT' + else: + neg = '' + if (index + 1 < len(having) and having[index + 1]) == 'or' or ( + index - 1 >= 0 and having[index - 1] == 'or'): + if AGG_OPS[item[2][1][0]] != '': + agg = ' ' + 
AGG_OPS[item[2][1][0]] + else: + agg = '' + if isinstance(item[3], dict): + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + agg + neg + ' ' + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + agg + neg + + ' ' + WHERE_OPS[item[1]]) + slot = get_labels(item[3], slot, 'OP_SEL') + else: + if AGG_OPS[item[2][1][0]] != '': + agg = ' ' + AGG_OPS[item[2][1][0]] + ' ' + else: + agg = ' ' + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + ' OR' + agg + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + ' OR' + agg + + WHERE_OPS[item[1]]) + elif item == 'and' or item == 'or': + continue + else: + if isinstance(item[3], dict): + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + neg + ' ' + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + neg + ' ' + + WHERE_OPS[item[1]]) + slot = get_labels(item[3], slot, 'OP_SEL') + else: + if AGG_OPS[item[2][1][0]] != '': + agg = ' ' + AGG_OPS[item[2][1][0]] + ' ' + else: + agg = ' ' + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + agg + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + agg + + WHERE_OPS[item[1]]) + return slot + + +def get_where_labels(where, slot, cur_nest): + if len(where) == 1: + item = where[0] + if item[0] is True: + neg = ' NOT' + else: + neg = '' + if isinstance(item[3], dict): + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + neg + ' ' + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + neg + ' ' + + WHERE_OPS[item[1]]) + slot = get_labels(item[3], slot, 'OP_SEL') + else: + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + neg + ' ' + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + neg + ' ' + + WHERE_OPS[item[1]]) + else: + for index, item in enumerate(where): + if item[0] is True: + neg = ' NOT' + else: + neg = '' + if (index + 1 < len(where) and where[index + 1]) == 'or' or ( + index - 1 >= 0 and where[index - 1] == 'or'): + if isinstance(item[3], dict): + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + neg + ' ' + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + neg + ' ' + + WHERE_OPS[item[1]]) + slot = get_labels(item[3], slot, 'OP_SEL') + else: + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + ' OR' + neg + ' ' + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + ' OR' + neg + + ' ' + WHERE_OPS[item[1]]) + elif item == 'and' or item == 'or': + continue + else: + if isinstance(item[3], dict): + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + neg + ' ' + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + neg + ' ' + + WHERE_OPS[item[1]]) + slot = get_labels(item[3], slot, 'OP_SEL') + else: + if slot[item[2][1][1]] == '': + slot[item[2][1][1]] += ( + cur_nest + neg + ' ' + WHERE_OPS[item[1]]) + else: + slot[item[2][1][1]] += (' ' + cur_nest + neg + ' ' + + WHERE_OPS[item[1]]) + return slot + + +def get_labels(sql_struct, slot, cur_nest): + + if len(sql_struct['select']) > 0: + if cur_nest != '': + slot = get_select_labels(sql_struct['select'], slot, + cur_nest + ' SELECT') + else: + slot = get_select_labels(sql_struct['select'], slot, 'SELECT') + + if sql_struct['from']: + if cur_nest != '': + slot = get_from_labels(sql_struct['from'], slot, 'FROM') + else: + slot = get_from_labels(sql_struct['from'], slot, 'FROM') + + if len(sql_struct['where']) > 0: + if 
cur_nest != '': + slot = get_where_labels(sql_struct['where'], slot, + cur_nest + ' WHERE') + else: + slot = get_where_labels(sql_struct['where'], slot, 'WHERE') + + if len(sql_struct['groupBy']) > 0: + if cur_nest != '': + slot = get_groupby_labels(sql_struct['groupBy'], slot, + cur_nest + ' GROUP_BY') + else: + slot = get_groupby_labels(sql_struct['groupBy'], slot, 'GROUP_BY') + + if len(sql_struct['having']) > 0: + if cur_nest != '': + slot = get_having_labels(sql_struct['having'], slot, + cur_nest + ' HAVING') + else: + slot = get_having_labels(sql_struct['having'], slot, 'HAVING') + + if len(sql_struct['orderBy']) > 0: + if cur_nest != '': + slot = get_orderby_labels(sql_struct['orderBy'], + sql_struct['limit'], slot, + cur_nest + ' ORDER_BY') + else: + slot = get_orderby_labels(sql_struct['orderBy'], + sql_struct['limit'], slot, 'ORDER_BY') + + if sql_struct['intersect']: + if cur_nest != '': + slot = get_intersect_labels(sql_struct['intersect'], slot, + cur_nest + ' INTERSECT') + else: + slot = get_intersect_labels(sql_struct['intersect'], slot, + 'INTERSECT') + + if sql_struct['except']: + if cur_nest != '': + slot = get_except_labels(sql_struct['except'], slot, + cur_nest + ' EXCEPT') + else: + slot = get_except_labels(sql_struct['except'], slot, 'EXCEPT') + + if sql_struct['union']: + if cur_nest != '': + slot = get_union_labels(sql_struct['union'], slot, + cur_nest + ' UNION') + else: + slot = get_union_labels(sql_struct['union'], slot, 'UNION') + return slot + + +def get_label(sql, column_len): + thelabel = [] + slot = {} + for idx in range(column_len): + slot[idx] = '' + for value in get_labels(sql, slot, '').values(): + thelabel.append(value) + return thelabel diff --git a/modelscope/preprocessors/star/fields/preprocess_dataset.py b/modelscope/preprocessors/star/fields/preprocess_dataset.py new file mode 100644 index 00000000..6c84c0e7 --- /dev/null +++ b/modelscope/preprocessors/star/fields/preprocess_dataset.py @@ -0,0 +1,37 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
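Reviewer note: `preprocess_dataset` below folds the dialogue into one question string: the current utterance first, then up to four previous turns, most recent first, joined with ` [CLS] `. For instance (values taken from the test case in this patch):

```python
utterance = 'Which of these are hiring?'
history = ["I'd like to see Shop names."]

question = utterance + ' [CLS] ' + ' [CLS] '.join(history[::-1][:4])
# "Which of these are hiring? [CLS] I'd like to see Shop names."
```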
+from text2sql_lgesql.preprocess.parse_raw_json import Schema, get_schemas
+from text2sql_lgesql.process_sql import get_sql
+
+from modelscope.preprocessors.star.fields.parse import get_label
+
+
+def preprocess_dataset(processor, dataset, output_tables, database_id,
+                       tables):
+    """Build a single Spider-style example from the raw dialogue state and
+    derive per-column labels from the SQL predicted for the last turn."""
+    schemas, db_names, thetables = get_schemas(tables)
+    intables = output_tables[database_id]
+    schema = schemas[database_id]
+    table = thetables[database_id]
+    if len(dataset['history']) == 0 or dataset['last_sql'] == '':
+        sql_label = [''] * len(intables['column_names'])
+    else:
+        schema = Schema(schema, table)
+        try:
+            parsed_sql = get_sql(schema, dataset['last_sql'])
+            sql_label = get_label(parsed_sql,
+                                  len(table['column_names_original']))
+        except Exception:
+            # fall back to empty labels if the last SQL cannot be parsed
+            sql_label = [''] * len(intables['column_names'])
+    theone = {'db_id': database_id}
+    theone['query'] = ''
+    theone['query_toks_no_value'] = []
+    theone['sql'] = {}
+    if len(dataset['history']) != 0:
+        # current utterance first, then up to four previous turns in
+        # reverse order, separated by [CLS]
+        theone['question'] = dataset['utterance'] + ' [CLS] ' + ' [CLS] '.join(
+            dataset['history'][::-1][:4])
+        theone['question_toks'] = theone['question'].split()
+    else:
+        theone['question'] = dataset['utterance']
+        theone['question_toks'] = dataset['utterance'].split()
+
+    return [theone], sql_label
diff --git a/modelscope/preprocessors/star/fields/process_dataset.py b/modelscope/preprocessors/star/fields/process_dataset.py
new file mode 100644
index 00000000..d8ac094a
--- /dev/null
+++ b/modelscope/preprocessors/star/fields/process_dataset.py
@@ -0,0 +1,64 @@
+# Copyright (c) rhythmcao modified from https://github.com/rhythmcao/text2sql-lgesql.
+
+import argparse
+import os
+import pickle
+import sys
+import time
+
+import json
+from text2sql_lgesql.asdl.asdl import ASDLGrammar
+from text2sql_lgesql.asdl.transition_system import TransitionSystem
+
+from modelscope.preprocessors.star.fields.common_utils import SubPreprocessor
+
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+
+
+def process_example(processor, entry, db, trans, verbose=False):
+    # preprocess raw tokens, schema linking and subgraph extraction
+    entry = processor.pipeline(entry, db, verbose=verbose)
+    # generate target output actions
+    entry['ast'] = []
+    entry['actions'] = []
+    return entry
+
+
+def process_tables(processor, tables_list, output_path=None, verbose=False):
+    tables = {}
+    for each in tables_list:
+        if verbose:
+            print('*************** Processing database %s **************' %
+                  (each['db_id']))
+        tables[each['db_id']] = processor.preprocess_database(
+            each, verbose=verbose)
+    print('In total, processed %d databases.'
% (len(tables))) + if output_path is not None: + pickle.dump(tables, open(output_path, 'wb')) + return tables + + +def process_dataset(model_dir, + processor, + dataset, + tables, + output_path=None, + skip_large=False, + verbose=False): + grammar = ASDLGrammar.from_filepath( + os.path.join(model_dir, 'sql_asdl_v2.txt')) + trans = TransitionSystem.get_class_by_lang('sql')(grammar) + processed_dataset = [] + for idx, entry in enumerate(dataset): + if skip_large and len(tables[entry['db_id']]['column_names']) > 100: + continue + if verbose: + print('*************** Processing %d-th sample **************' % + (idx)) + entry = process_example( + processor, entry, tables[entry['db_id']], trans, verbose=verbose) + processed_dataset.append(entry) + if output_path is not None: + # serialize preprocessed dataset + pickle.dump(processed_dataset, open(output_path, 'wb')) + return processed_dataset diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 538aa3db..be077551 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -89,6 +89,7 @@ class NLPTasks(object): zero_shot_classification = 'zero-shot-classification' backbone = 'backbone' text_error_correction = 'text-error-correction' + conversational_text_to_sql = 'conversational-text-to-sql' class AudioTasks(object): diff --git a/requirements/nlp.txt b/requirements/nlp.txt index 9bc543d7..6bd56aff 100644 --- a/requirements/nlp.txt +++ b/requirements/nlp.txt @@ -6,5 +6,6 @@ pai-easynlp rouge_score<=0.0.4 seqeval spacy>=2.3.5 +text2sql_lgesql tokenizers transformers>=4.12.0 diff --git a/tests/pipelines/test_conversational_text_to_sql.py b/tests/pipelines/test_conversational_text_to_sql.py new file mode 100644 index 00000000..67a4ce7b --- /dev/null +++ b/tests/pipelines/test_conversational_text_to_sql.py @@ -0,0 +1,97 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest +from typing import List + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import StarForTextToSql +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import ConversationalTextToSqlPipeline +from modelscope.preprocessors import ConversationalTextToSqlPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class ConversationalTextToSql(unittest.TestCase): + model_id = 'damo/nlp_star_conversational-text-to-sql' + test_case = { + 'database_id': + 'employee_hire_evaluation', + 'local_db_path': + None, + 'utterance': [ + "I'd like to see Shop names.", 'Which of these are hiring?', + 'Which shop is hiring the highest number of employees? | do you want the name of the shop ? 
| Yes' + ] + } + + def tracking_and_print_results( + self, pipelines: List[ConversationalTextToSqlPipeline]): + for my_pipeline in pipelines: + last_sql, history = '', [] + for item in self.test_case['utterance']: + case = { + 'utterance': item, + 'history': history, + 'last_sql': last_sql, + 'database_id': self.test_case['database_id'], + 'local_db_path': self.test_case['local_db_path'] + } + results = my_pipeline(case) + print({'question': item}) + print(results) + last_sql = results['text'] + history.append(item) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + preprocessor = ConversationalTextToSqlPreprocessor( + model_dir=cache_path, + database_id=self.test_case['database_id'], + db_content=True) + model = StarForTextToSql( + model_dir=cache_path, config=preprocessor.config) + + pipelines = [ + ConversationalTextToSqlPipeline( + model=model, preprocessor=preprocessor), + pipeline( + task=Tasks.conversational_text_to_sql, + model=model, + preprocessor=preprocessor) + ] + self.tracking_and_print_results(pipelines) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + preprocessor = ConversationalTextToSqlPreprocessor( + model_dir=model.model_dir) + + pipelines = [ + ConversationalTextToSqlPipeline( + model=model, preprocessor=preprocessor), + pipeline( + task=Tasks.conversational_text_to_sql, + model=model, + preprocessor=preprocessor) + ] + self.tracking_and_print_results(pipelines) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipelines = [ + pipeline( + task=Tasks.conversational_text_to_sql, model=self.model_id) + ] + self.tracking_and_print_results(pipelines) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipelines = [pipeline(task=Tasks.conversational_text_to_sql)] + self.tracking_and_print_results(pipelines) + + +if __name__ == '__main__': + unittest.main()
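Reviewer note: taken together, the intended multi-turn usage looks like the loop below, distilled from `tracking_and_print_results` above (model id, input keys and database id are exactly those of the test case):

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

p = pipeline(
    task=Tasks.conversational_text_to_sql,
    model='damo/nlp_star_conversational-text-to-sql')

last_sql, history = '', []
for utterance in ["I'd like to see Shop names.", 'Which of these are hiring?']:
    result = p({
        'utterance': utterance,
        'history': history,
        'last_sql': last_sql,
        'database_id': 'employee_hire_evaluation',
        'local_db_path': None
    })
    last_sql = result['text']  # the predicted SQL feeds the next turn
    history.append(utterance)
```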