Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9580066master
@@ -29,6 +29,7 @@ class Models(object):
     space_dst = 'space-dst'
     space_intent = 'space-intent'
     space_modeling = 'space-modeling'
+    star = 'star'
     tcrf = 'transformer-crf'
     bart = 'bart'
     gpt3 = 'gpt3'

@@ -123,6 +124,7 @@ class Pipelines(object):
     dialog_state_tracking = 'dialog-state-tracking'
     zero_shot_classification = 'zero-shot-classification'
     text_error_correction = 'text-error-correction'
+    conversational_text_to_sql = 'conversational-text-to-sql'
 
     # audio tasks
     sambert_hifigan_tts = 'sambert-hifigan-tts'

@@ -201,6 +203,7 @@ class Preprocessors(object):
     text_error_correction = 'text-error-correction'
     word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor'
     fill_mask = 'fill-mask'
+    conversational_text_to_sql = 'conversational-text-to-sql'
 
     # audio preprocessor
     linear_aec_fbank = 'linear-aec-fbank'

@@ -17,12 +17,14 @@ if TYPE_CHECKING:
     from .space import SpaceForDialogIntent
     from .space import SpaceForDialogModeling
     from .space import SpaceForDialogStateTracking
+    from .star_text_to_sql import StarForTextToSql
     from .task_models.task_model import SingleBackboneTaskModelBase
     from .bart_for_text_error_correction import BartForTextErrorCorrection
     from .gpt3 import GPT3ForTextGeneration
 
 else:
     _import_structure = {
+        'star_text_to_sql': ['StarForTextToSql'],
         'backbones': ['SbertModel'],
         'heads': ['SequenceClassificationHead'],
         'csanmt_for_translation': ['CsanmtForTranslation'],

@@ -0,0 +1,68 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Dict, Optional

import torch
import torch.nn as nn
from text2sql_lgesql.asdl.asdl import ASDLGrammar
from text2sql_lgesql.asdl.transition_system import TransitionSystem
from text2sql_lgesql.model.model_constructor import Text2SQL
from text2sql_lgesql.utils.constants import GRAMMAR_FILEPATH

from modelscope.metainfo import Models
from modelscope.models.base import Model, Tensor
from modelscope.models.builder import MODELS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks

__all__ = ['StarForTextToSql']


@MODELS.register_module(
    Tasks.conversational_text_to_sql, module_name=Models.star)
class StarForTextToSql(Model):

    def __init__(self, model_dir: str, *args, **kwargs):
        """initialize the star model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.beam_size = 5
        self.config = kwargs.pop(
            'config',
            Config.from_file(
                os.path.join(self.model_dir, ModelFile.CONFIGURATION)))
        self.config.model.model_dir = model_dir
        self.grammar = ASDLGrammar.from_filepath(
            os.path.join(model_dir, 'sql_asdl_v2.txt'))
        self.trans = TransitionSystem.get_class_by_lang('sql')(self.grammar)
        self.arg = self.config.model
        self.device = 'cuda' if \
            ('device' not in kwargs or kwargs['device'] == 'gpu') \
            and torch.cuda.is_available() else 'cpu'
        self.model = Text2SQL(self.arg, self.trans)
        check_point = torch.load(
            open(
                os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'rb'),
            map_location=self.device)
        self.model.load_state_dict(check_point['model'])

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """return the result predicted by the model

        Args:
            input (Dict[str, Tensor]): the preprocessed data

        Returns:
            Dict[str, Tensor]: the beam-search hypotheses and the database
        """
        self.model.eval()
        hyps = self.model.parse(input['batch'], self.beam_size)
        db = input['batch'].examples[0].db
        predict = {'predict': hyps, 'db': db}
        return predict
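
Review note (not part of the diff): a rough sketch of the forward contract above. It assumes the input comes from ConversationalTextToSqlPreprocessor and that the pipeline further below has already attached its SubPreprocessor; `preprocessor`, `case` and `star_model` are placeholder names.

    # Illustrative sketch only.
    inputs = preprocessor(case)
    # inputs == {'batch': <text2sql_lgesql Batch with one Example>,
    #            'db': 'employee_hire_evaluation'}
    outputs = star_model.forward(inputs)
    # outputs == {'predict': <beam-search hypotheses>, 'db': <lgesql database entry>}
    # ConversationalTextToSqlPipeline.postprocess() later converts 'predict' into
    # {OutputKeys.TEXT: 'SELECT shop.Name FROM shop'}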

@@ -389,6 +389,12 @@ TASK_OUTPUTS = {
     # }
     Tasks.task_oriented_conversation: [OutputKeys.OUTPUT],
 
+    # conversational text-to-sql result for single sample
+    # {
+    #     "text": "SELECT shop.Name FROM shop."
+    # }
+    Tasks.conversational_text_to_sql: [OutputKeys.TEXT],
+
     # ============ audio tasks ===================
     # asr result for single sample
     # { "text": "每一天都要快乐喔"}

@@ -239,6 +239,7 @@ class Pipeline(ABC):
         """
         from torch.utils.data.dataloader import default_collate
         from modelscope.preprocessors import InputFeatures
+        from text2sql_lgesql.utils.batch import Batch
         if isinstance(data, dict) or isinstance(data, Mapping):
             return type(data)(
                 {k: self._collate_fn(v)

@@ -259,6 +260,8 @@ class Pipeline(ABC):
             return data
         elif isinstance(data, InputFeatures):
             return data
+        elif isinstance(data, Batch):
+            return data
         else:
             import mmcv
             if isinstance(data, mmcv.parallel.data_container.DataContainer):

@@ -50,6 +50,11 @@ DEFAULT_MODEL_FOR_PIPELINE = {
      'damo/nlp_structbert_zero-shot-classification_chinese-base'),
     Tasks.task_oriented_conversation: (Pipelines.dialog_modeling,
                                        'damo/nlp_space_dialog-modeling'),
+    Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking,
+                                  'damo/nlp_space_dialog-state-tracking'),
+    Tasks.conversational_text_to_sql:
+    (Pipelines.conversational_text_to_sql,
+     'damo/nlp_star_conversational-text-to-sql'),
     Tasks.text_error_correction:
     (Pipelines.text_error_correction,
      'damo/nlp_bart_text-error-correction_chinese'),

@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
 from modelscope.utils.import_utils import LazyImportModule
 
 if TYPE_CHECKING:
+    from .conversational_text_to_sql_pipeline import ConversationalTextToSqlPipeline
     from .dialog_intent_prediction_pipeline import DialogIntentPredictionPipeline
     from .dialog_modeling_pipeline import DialogModelingPipeline
     from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline

@@ -22,6 +23,8 @@ if TYPE_CHECKING:
 else:
     _import_structure = {
+        'conversational_text_to_sql_pipeline':
+        ['ConversationalTextToSqlPipeline'],
         'dialog_intent_prediction_pipeline':
         ['DialogIntentPredictionPipeline'],
         'dialog_modeling_pipeline': ['DialogModelingPipeline'],

@@ -0,0 +1,66 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Union

import torch
from text2sql_lgesql.utils.example import Example

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.models.nlp import StarForTextToSql
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import ConversationalTextToSqlPreprocessor
from modelscope.preprocessors.star.fields.common_utils import SubPreprocessor
from modelscope.preprocessors.star.fields.process_dataset import process_tables
from modelscope.utils.constant import Tasks

__all__ = ['ConversationalTextToSqlPipeline']


@PIPELINES.register_module(
    Tasks.conversational_text_to_sql,
    module_name=Pipelines.conversational_text_to_sql)
class ConversationalTextToSqlPipeline(Pipeline):

    def __init__(self,
                 model: Union[StarForTextToSql, str],
                 preprocessor: ConversationalTextToSqlPreprocessor = None,
                 **kwargs):
        """use `model` and `preprocessor` to create a conversational
        text-to-sql prediction pipeline

        Args:
            model (StarForTextToSql): a model instance
            preprocessor (ConversationalTextToSqlPreprocessor):
                a preprocessor instance
        """
        model = model if isinstance(
            model, StarForTextToSql) else Model.from_pretrained(model)
        if preprocessor is None:
            preprocessor = ConversationalTextToSqlPreprocessor(model.model_dir)
        preprocessor.device = 'cuda' if \
            ('device' not in kwargs or kwargs['device'] == 'gpu') \
            and torch.cuda.is_available() else 'cpu'
        use_device = True if preprocessor.device == 'cuda' else False
        preprocessor.processor = \
            SubPreprocessor(model_dir=model.model_dir,
                            db_content=True,
                            use_gpu=use_device)
        preprocessor.output_tables = \
            process_tables(preprocessor.processor,
                           preprocessor.tables)
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """process the prediction results

        Args:
            inputs (Dict[str, Any]): the model output, with the predicted
                hypotheses under 'predict' and the database under 'db'

        Returns:
            Dict[str, str]: the prediction results
        """
        sql = Example.evaluator.obtain_sql(inputs['predict'][0], inputs['db'])
        result = {OutputKeys.TEXT: sql}
        return result
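
Review note (not part of the diff): a minimal end-to-end sketch of driving this pipeline across dialogue turns; the model id and input keys mirror the test case at the end of this review.

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    text2sql = pipeline(
        task=Tasks.conversational_text_to_sql,
        model='damo/nlp_star_conversational-text-to-sql')

    last_sql, history = '', []
    for utterance in ["I'd like to see Shop names.", 'Which of these are hiring?']:
        case = {
            'utterance': utterance,
            'history': history,
            'last_sql': last_sql,
            'database_id': 'employee_hire_evaluation',
            'local_db_path': None
        }
        result = text2sql(case)
        last_sql = result[OutputKeys.TEXT]  # e.g. 'SELECT shop.Name FROM shop'
        history.append(utterance)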

@@ -27,6 +27,7 @@ if TYPE_CHECKING:
                         DialogModelingPreprocessor,
                         DialogStateTrackingPreprocessor)
     from .video import ReadVideoData
+    from .star import ConversationalTextToSqlPreprocessor
 
 else:
     _import_structure = {

@@ -55,6 +56,7 @@ else:
             'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',
             'DialogStateTrackingPreprocessor', 'InputFeatures'
         ],
+        'star': ['ConversationalTextToSqlPreprocessor'],
     }
 
     import sys

@@ -0,0 +1,29 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .conversational_text_to_sql_preprocessor import \
        ConversationalTextToSqlPreprocessor
    from .fields import (get_label, SubPreprocessor, preprocess_dataset,
                         process_dataset)

else:
    _import_structure = {
        'conversational_text_to_sql_preprocessor':
        ['ConversationalTextToSqlPreprocessor'],
        'fields': [
            'get_label', 'SubPreprocessor', 'preprocess_dataset',
            'process_dataset'
        ]
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

@@ -0,0 +1,111 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict

import json
import torch
from text2sql_lgesql.preprocess.graph_utils import GraphProcessor
from text2sql_lgesql.preprocess.process_graphs import process_dataset_graph
from text2sql_lgesql.utils.batch import Batch
from text2sql_lgesql.utils.example import Example

from modelscope.metainfo import Preprocessors
from modelscope.preprocessors.base import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.preprocessors.star.fields.preprocess_dataset import \
    preprocess_dataset
from modelscope.preprocessors.star.fields.process_dataset import (
    process_dataset, process_tables)
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModelFile
from modelscope.utils.type_assert import type_assert

__all__ = ['ConversationalTextToSqlPreprocessor']


@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.conversational_text_to_sql)
class ConversationalTextToSqlPreprocessor(Preprocessor):

    def __init__(self, model_dir: str, *args, **kwargs):
        """preprocess the data via the resources in the `model_dir` path

        Args:
            model_dir (str): model path
        """
        super().__init__(*args, **kwargs)

        self.model_dir: str = model_dir
        self.config = Config.from_file(
            os.path.join(self.model_dir, ModelFile.CONFIGURATION))
        self.device = 'cuda' if \
            ('device' not in kwargs or kwargs['device'] == 'gpu') \
            and torch.cuda.is_available() else 'cpu'
        self.processor = None
        self.table_path = os.path.join(self.model_dir, 'tables.json')
        self.tables = json.load(open(self.table_path, 'r'))
        self.output_tables = None
        self.path_cache = []
        self.graph_processor = GraphProcessor()
        Example.configuration(
            plm=self.config['model']['plm'],
            tables=self.output_tables,
            table_path=os.path.join(model_dir, 'tables.json'),
            model_dir=self.model_dir,
            db_dir=os.path.join(model_dir, 'db'))

    @type_assert(object, dict)
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """process the raw input data

        Args:
            data (dict):
                utterance: the current user question
                history: the list of previous questions in the dialogue
                last_sql: the SQL predicted for the previous turn
                    ('' on the first turn)
                database_id: the id of the database to query
                local_db_path: an optional local database directory, or None
            Example:
                utterance: 'Which of these are hiring?'
                last_sql: ''

        Returns:
            Dict[str, Any]: the preprocessed data
        """
        # use local database
        if data['local_db_path'] is not None and data[
                'local_db_path'] not in self.path_cache:
            self.path_cache.append(data['local_db_path'])
            path = os.path.join(data['local_db_path'], 'tables.json')
            self.tables = json.load(open(path, 'r'))
            self.processor.db_dir = os.path.join(data['local_db_path'], 'db')
            self.output_tables = process_tables(self.processor, self.tables)
            Example.configuration(
                plm=self.config['model']['plm'],
                tables=self.output_tables,
                table_path=path,
                model_dir=self.model_dir,
                db_dir=self.processor.db_dir)

        theresult, sql_label = \
            preprocess_dataset(
                self.processor,
                data,
                self.output_tables,
                data['database_id'],
                self.tables
            )
        output_dataset = process_dataset(self.model_dir, self.processor,
                                         theresult, self.output_tables)
        output_dataset = \
            process_dataset_graph(
                self.graph_processor,
                output_dataset,
                self.output_tables,
                method='lgesql'
            )
        dev_ex = Example(output_dataset[0],
                         self.output_tables[data['database_id']], sql_label)
        current_batch = Batch.from_example_list([dev_ex],
                                                self.device,
                                                train=False)
        return {'batch': current_batch, 'db': data['database_id']}
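
Review note (not part of the diff): the local-database branch above lets a caller point the preprocessor at a database directory of their own. A sketch of the expected call; the directory path is purely hypothetical and must contain a Spider-style tables.json plus db/<database_id>/<database_id>.sqlite.

    # Hypothetical local layout (the path is made up):
    #   /data/my_spider_dbs/tables.json
    #   /data/my_spider_dbs/db/employee_hire_evaluation/employee_hire_evaluation.sqlite
    case = {
        'utterance': 'Which of these are hiring?',
        'history': [],
        'last_sql': '',
        'database_id': 'employee_hire_evaluation',
        'local_db_path': '/data/my_spider_dbs'
    }
    outputs = preprocessor(case)
    # outputs == {'batch': <Batch with one Example>, 'db': 'employee_hire_evaluation'}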

@@ -0,0 +1,6 @@
from modelscope.preprocessors.star.fields.common_utils import SubPreprocessor
from modelscope.preprocessors.star.fields.parse import get_label
from modelscope.preprocessors.star.fields.preprocess_dataset import \
    preprocess_dataset
from modelscope.preprocessors.star.fields.process_dataset import \
    process_dataset

@@ -0,0 +1,471 @@
# Copyright (c) rhythmcao modified from https://github.com/rhythmcao/text2sql-lgesql.
import os
import sqlite3
from itertools import combinations, product

import nltk
import numpy as np
from text2sql_lgesql.utils.constants import MAX_RELATIVE_DIST

from modelscope.utils.logger import get_logger

mwtokenizer = nltk.MWETokenizer(separator='')
mwtokenizer.add_mwe(('[', 'CLS', ']'))
logger = get_logger()


def is_number(s):
    try:
        float(s)
        return True
    except ValueError:
        return False


def quote_normalization(question):
    """ Normalize all usage of quotation marks into a separate \" """
    new_question, quotation_marks = [], [
        "'", '"', '`', '‘', '’', '“', '”', '``', "''", '‘‘', '’’'
    ]
    for idx, tok in enumerate(question):
        if len(tok) > 2 and tok[0] in quotation_marks and tok[
                -1] in quotation_marks:
            new_question += ["\"", tok[1:-1], "\""]
        elif len(tok) > 2 and tok[0] in quotation_marks:
            new_question += ["\"", tok[1:]]
        elif len(tok) > 2 and tok[-1] in quotation_marks:
            new_question += [tok[:-1], "\""]
        elif tok in quotation_marks:
            new_question.append("\"")
        elif len(tok) == 2 and tok[0] in quotation_marks:
            # special case: the length of entity value is 1
            if idx + 1 < len(question) and question[idx
                                                    + 1] in quotation_marks:
                new_question += ["\"", tok[1]]
            else:
                new_question.append(tok)
        else:
            new_question.append(tok)
    return new_question
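
Review note (not part of the diff): a hand-traced example of the normalization above, not an executed doctest.

    quote_normalization(['Shops', 'named', "'Apple", "Store'"])
    # -> ['Shops', 'named', '"', 'Apple', 'Store', '"']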

class SubPreprocessor():

    def __init__(self, model_dir, use_gpu=False, db_content=True):
        super(SubPreprocessor, self).__init__()
        self.model_dir = model_dir
        self.db_dir = os.path.join(model_dir, 'db')
        self.db_content = db_content
        from nltk import data
        from nltk.corpus import stopwords
        data.path.append(os.path.join(self.model_dir, 'nltk_data'))
        self.stopwords = stopwords.words('english')
        import stanza
        from stanza.resources import common
        from stanza.pipeline import core
        self.nlp = stanza.Pipeline(
            'en',
            use_gpu=use_gpu,
            dir=self.model_dir,
            processors='tokenize,pos,lemma',
            tokenize_pretokenized=True,
            download_method=core.DownloadMethod.REUSE_RESOURCES)
        self.nlp1 = stanza.Pipeline(
            'en',
            use_gpu=use_gpu,
            dir=self.model_dir,
            processors='tokenize,pos,lemma',
            download_method=core.DownloadMethod.REUSE_RESOURCES)

    def pipeline(self, entry: dict, db: dict, verbose: bool = False):
        """ db should be preprocessed """
        entry = self.preprocess_question(entry, db, verbose=verbose)
        entry = self.schema_linking(entry, db, verbose=verbose)
        entry = self.extract_subgraph(entry, db, verbose=verbose)
        return entry

    def preprocess_database(self, db: dict, verbose: bool = False):
        table_toks, table_names = [], []
        for tab in db['table_names']:
            doc = self.nlp1(tab)
            tab = [w.lemma.lower() for s in doc.sentences for w in s.words]
            table_toks.append(tab)
            table_names.append(' '.join(tab))
        db['processed_table_toks'], db[
            'processed_table_names'] = table_toks, table_names
        column_toks, column_names = [], []
        for _, c in db['column_names']:
            doc = self.nlp1(c)
            c = [w.lemma.lower() for s in doc.sentences for w in s.words]
            column_toks.append(c)
            column_names.append(' '.join(c))
        db['processed_column_toks'], db[
            'processed_column_names'] = column_toks, column_names
        column2table = list(map(lambda x: x[0], db['column_names']))
        table2columns = [[] for _ in range(len(table_names))]
        for col_id, col in enumerate(db['column_names']):
            if col_id == 0:
                continue
            table2columns[col[0]].append(col_id)
        db['column2table'], db['table2columns'] = column2table, table2columns
        t_num, c_num, dtype = len(db['table_names']), len(
            db['column_names']), '<U100'
        tab_mat = np.array([['table-table-generic'] * t_num
                            for _ in range(t_num)],
                           dtype=dtype)
        table_fks = set(
            map(lambda pair: (column2table[pair[0]], column2table[pair[1]]),
                db['foreign_keys']))
        for (tab1, tab2) in table_fks:
            if (tab2, tab1) in table_fks:
                tab_mat[tab1, tab2], tab_mat[
                    tab2, tab1] = 'table-table-fkb', 'table-table-fkb'
            else:
                tab_mat[tab1, tab2], tab_mat[
                    tab2, tab1] = 'table-table-fk', 'table-table-fkr'
        tab_mat[list(range(t_num)),
                list(range(t_num))] = 'table-table-identity'
        col_mat = np.array([['column-column-generic'] * c_num
                            for _ in range(c_num)],
                           dtype=dtype)
        for i in range(t_num):
            col_ids = [idx for idx, t in enumerate(column2table) if t == i]
            col1, col2 = list(zip(*list(product(col_ids, col_ids))))
            col_mat[col1, col2] = 'column-column-sametable'
        col_mat[list(range(c_num)),
                list(range(c_num))] = 'column-column-identity'
        if len(db['foreign_keys']) > 0:
            col1, col2 = list(zip(*db['foreign_keys']))
            col_mat[col1, col2], col_mat[
                col2, col1] = 'column-column-fk', 'column-column-fkr'
        col_mat[0, list(range(c_num))] = '*-column-generic'
        col_mat[list(range(c_num)), 0] = 'column-*-generic'
        col_mat[0, 0] = '*-*-identity'
        # relations between tables and columns, t_num*c_num and c_num*t_num
        tab_col_mat = np.array([['table-column-generic'] * c_num
                                for _ in range(t_num)],
                               dtype=dtype)
        col_tab_mat = np.array([['column-table-generic'] * t_num
                                for _ in range(c_num)],
                               dtype=dtype)
        cols, tabs = list(
            zip(*list(map(lambda x: (x, column2table[x]), range(1, c_num)))))
        col_tab_mat[cols, tabs], tab_col_mat[
            tabs, cols] = 'column-table-has', 'table-column-has'
        if len(db['primary_keys']) > 0:
            cols, tabs = list(
                zip(*list(
                    map(lambda x: (x, column2table[x]), db['primary_keys']))))
            col_tab_mat[cols, tabs], tab_col_mat[
                tabs, cols] = 'column-table-pk', 'table-column-pk'
        col_tab_mat[0, list(range(t_num))] = '*-table-generic'
        tab_col_mat[list(range(t_num)), 0] = 'table-*-generic'
        relations = \
            np.concatenate([
                np.concatenate([tab_mat, tab_col_mat], axis=1),
                np.concatenate([col_tab_mat, col_mat], axis=1)
            ], axis=0)
        db['relations'] = relations.tolist()
        if verbose:
            print('Tables:', ', '.join(db['table_names']))
            print('Lemmatized:', ', '.join(table_names))
            print('Columns:',
                  ', '.join(list(map(lambda x: x[1], db['column_names']))))
            print('Lemmatized:', ', '.join(column_names), '\n')
        return db

    def preprocess_question(self,
                            entry: dict,
                            db: dict,
                            verbose: bool = False):
        """ Tokenize, lemmatize, lowercase question"""
        # stanza tokenize, lemmatize and POS tag
        question = ' '.join(quote_normalization(entry['question_toks']))
        from nltk import data
        data.path.append(os.path.join(self.model_dir, 'nltk_data'))
        question = nltk.word_tokenize(question)
        question = mwtokenizer.tokenize(question)
        doc = self.nlp([question])
        raw_toks = [w.text.lower() for s in doc.sentences for w in s.words]
        toks = [w.lemma.lower() for s in doc.sentences for w in s.words]
        pos_tags = [w.xpos for s in doc.sentences for w in s.words]
        entry['raw_question_toks'] = raw_toks
        entry['processed_question_toks'] = toks
        entry['pos_tags'] = pos_tags
        q_num, dtype = len(toks), '<U100'
        if q_num <= MAX_RELATIVE_DIST + 1:
            dist_vec = [
                'question-question-dist'
                + str(i) if i != 0 else 'question-question-identity'
                for i in range(-MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1, 1)
            ]
            starting = MAX_RELATIVE_DIST
        else:
            dist_vec = ['question-question-generic'] \
                * (q_num - MAX_RELATIVE_DIST - 1) + \
                [
                    'question-question-dist' + str(i)
                    if i != 0 else 'question-question-identity'
                    for i in range(-MAX_RELATIVE_DIST, MAX_RELATIVE_DIST + 1,
                                   1)] \
                + ['question-question-generic'] \
                * (q_num - MAX_RELATIVE_DIST - 1)
            starting = q_num - 1
        list_data = \
            [dist_vec[starting - i:starting - i + q_num] for i in range(q_num)]
        q_mat = \
            np.array(
                list_data,
                dtype=dtype
            )
        entry['relations'] = q_mat.tolist()
        if verbose:
            print('Question:', entry['question'])
            print('Tokenized:', ' '.join(entry['raw_question_toks']))
            print('Lemmatized:', ' '.join(entry['processed_question_toks']))
            print('Pos tags:', ' '.join(entry['pos_tags']), '\n')
        return entry

    def extract_subgraph(self, entry: dict, db: dict, verbose: bool = False):
        used_schema = {'table': set(), 'column': set()}
        entry['used_tables'] = sorted(list(used_schema['table']))
        entry['used_columns'] = sorted(list(used_schema['column']))
        if verbose:
            print('Used tables:', entry['used_tables'])
            print('Used columns:', entry['used_columns'], '\n')
        return entry

    def extract_subgraph_from_sql(self, sql: dict, used_schema: dict):
        select_items = sql['select'][1]
        # select clause
        for _, val_unit in select_items:
            if val_unit[0] == 0:
                col_unit = val_unit[1]
                used_schema['column'].add(col_unit[1])
            else:
                col_unit1, col_unit2 = val_unit[1:]
                used_schema['column'].add(col_unit1[1])
                used_schema['column'].add(col_unit2[1])
        # from clause conds
        table_units = sql['from']['table_units']
        for _, t in table_units:
            if type(t) == dict:
                used_schema = self.extract_subgraph_from_sql(t, used_schema)
            else:
                used_schema['table'].add(t)
        # from, where and having conds
        used_schema = self.extract_subgraph_from_conds(sql['from']['conds'],
                                                       used_schema)
        used_schema = self.extract_subgraph_from_conds(sql['where'],
                                                       used_schema)
        used_schema = self.extract_subgraph_from_conds(sql['having'],
                                                       used_schema)
        # groupBy and orderBy clause
        groupBy = sql['groupBy']
        for col_unit in groupBy:
            used_schema['column'].add(col_unit[1])
        orderBy = sql['orderBy']
        if len(orderBy) > 0:
            orderBy = orderBy[1]
            for val_unit in orderBy:
                if val_unit[0] == 0:
                    col_unit = val_unit[1]
                    used_schema['column'].add(col_unit[1])
                else:
                    col_unit1, col_unit2 = val_unit[1:]
                    used_schema['column'].add(col_unit1[1])
                    used_schema['column'].add(col_unit2[1])
        # union, intersect and except clause
        if sql['intersect']:
            used_schema = self.extract_subgraph_from_sql(
                sql['intersect'], used_schema)
        if sql['union']:
            used_schema = self.extract_subgraph_from_sql(
                sql['union'], used_schema)
        if sql['except']:
            used_schema = self.extract_subgraph_from_sql(
                sql['except'], used_schema)
        return used_schema

    def extract_subgraph_from_conds(self, conds: list, used_schema: dict):
        if len(conds) == 0:
            return used_schema
        for cond in conds:
            if cond in ['and', 'or']:
                continue
            val_unit, val1, val2 = cond[2:]
            if val_unit[0] == 0:
                col_unit = val_unit[1]
                used_schema['column'].add(col_unit[1])
            else:
                col_unit1, col_unit2 = val_unit[1:]
                used_schema['column'].add(col_unit1[1])
                used_schema['column'].add(col_unit2[1])
            if type(val1) == list:
                used_schema['column'].add(val1[1])
            elif type(val1) == dict:
                used_schema = self.extract_subgraph_from_sql(val1, used_schema)
            if type(val2) == list:
                used_schema['column'].add(val2[1])
            elif type(val2) == dict:
                used_schema = self.extract_subgraph_from_sql(val2, used_schema)
        return used_schema

    def schema_linking(self, entry: dict, db: dict, verbose: bool = False):
        raw_question_toks, question_toks = entry['raw_question_toks'], entry[
            'processed_question_toks']
        table_toks, column_toks = db['processed_table_toks'], db[
            'processed_column_toks']
        table_names, column_names = db['processed_table_names'], db[
            'processed_column_names']
        q_num, t_num, c_num, dtype = len(question_toks), len(table_toks), len(
            column_toks), '<U100'
        # relations between questions and tables, q_num*t_num and t_num*q_num
        table_matched_pairs = {'partial': [], 'exact': []}
        q_tab_mat = np.array([['question-table-nomatch'] * t_num
                              for _ in range(q_num)],
                             dtype=dtype)
        tab_q_mat = np.array([['table-question-nomatch'] * q_num
                              for _ in range(t_num)],
                             dtype=dtype)
        max_len = max([len(t) for t in table_toks])
        index_pairs = list(
            filter(lambda x: x[1] - x[0] <= max_len,
                   combinations(range(q_num + 1), 2)))
        index_pairs = sorted(index_pairs, key=lambda x: x[1] - x[0])
        for i, j in index_pairs:
            phrase = ' '.join(question_toks[i:j])
            if phrase in self.stopwords:
                continue
            for idx, name in enumerate(table_names):
                if phrase == name:
                    q_tab_mat[range(i, j), idx] = 'question-table-exactmatch'
                    tab_q_mat[idx, range(i, j)] = 'table-question-exactmatch'
                    if verbose:
                        table_matched_pairs['exact'].append(
                            str((name, idx, phrase, i, j)))
                elif (j - i == 1
                      and phrase in name.split()) or (j - i > 1
                                                      and phrase in name):
                    q_tab_mat[range(i, j), idx] = 'question-table-partialmatch'
                    tab_q_mat[idx, range(i, j)] = 'table-question-partialmatch'
                    if verbose:
                        table_matched_pairs['partial'].append(
                            str((name, idx, phrase, i, j)))
        # relations between questions and columns
        column_matched_pairs = {'partial': [], 'exact': [], 'value': []}
        q_col_mat = np.array([['question-column-nomatch'] * c_num
                              for _ in range(q_num)],
                             dtype=dtype)
        col_q_mat = np.array([['column-question-nomatch'] * q_num
                              for _ in range(c_num)],
                             dtype=dtype)
        max_len = max([len(c) for c in column_toks])
        index_pairs = list(
            filter(lambda x: x[1] - x[0] <= max_len,
                   combinations(range(q_num + 1), 2)))
        index_pairs = sorted(index_pairs, key=lambda x: x[1] - x[0])
        for i, j in index_pairs:
            phrase = ' '.join(question_toks[i:j])
            if phrase in self.stopwords:
                continue
            for idx, name in enumerate(column_names):
                if phrase == name:
                    q_col_mat[range(i, j), idx] = 'question-column-exactmatch'
                    col_q_mat[idx, range(i, j)] = 'column-question-exactmatch'
                    if verbose:
                        column_matched_pairs['exact'].append(
                            str((name, idx, phrase, i, j)))
                elif (j - i == 1
                      and phrase in name.split()) or (j - i > 1
                                                      and phrase in name):
                    q_col_mat[range(i, j),
                              idx] = 'question-column-partialmatch'
                    col_q_mat[idx,
                              range(i, j)] = 'column-question-partialmatch'
                    if verbose:
                        column_matched_pairs['partial'].append(
                            str((name, idx, phrase, i, j)))
        if self.db_content:
            db_file = os.path.join(self.db_dir, db['db_id'],
                                   db['db_id'] + '.sqlite')
            if not os.path.exists(db_file):
                raise ValueError('[ERROR]: database file %s not found ...' %
                                 (db_file))
            conn = sqlite3.connect(db_file)
            conn.text_factory = lambda b: b.decode(errors='ignore')
            conn.execute('pragma foreign_keys=ON')
            for i, (tab_id,
                    col_name) in enumerate(db['column_names_original']):
                if i == 0 or 'id' in column_toks[
                        i]:  # ignore * and special token 'id'
                    continue
                tab_name = db['table_names_original'][tab_id]
                try:
                    cursor = conn.execute('SELECT DISTINCT "%s" FROM "%s";'
                                          % (col_name, tab_name))
                    cell_values = cursor.fetchall()
                    cell_values = [str(each[0]) for each in cell_values]
                    cell_values = [[str(float(each))] if is_number(each) else
                                   each.lower().split()
                                   for each in cell_values]
                except Exception as e:
                    print(e)
                    cell_values = []
                for j, word in enumerate(raw_question_toks):
                    word = str(float(word)) if is_number(word) else word
                    for c in cell_values:
                        if word in c and 'nomatch' in q_col_mat[
                                j, i] and word not in self.stopwords:
                            q_col_mat[j, i] = 'question-column-valuematch'
                            col_q_mat[i, j] = 'column-question-valuematch'
                            if verbose:
                                column_matched_pairs['value'].append(
                                    str((column_names[i], i, word, j, j + 1)))
                            break
            conn.close()
        q_col_mat[:, 0] = 'question-*-generic'
        col_q_mat[0] = '*-question-generic'
        q_schema = np.concatenate([q_tab_mat, q_col_mat], axis=1)
        schema_q = np.concatenate([tab_q_mat, col_q_mat], axis=0)
        entry['schema_linking'] = (q_schema.tolist(), schema_q.tolist())
        if verbose:
            print('Question:', ' '.join(question_toks))
            print('Table matched: (table name, column id, '
                  'question span, start id, end id)')
            print(
                'Exact match:', ', '.join(table_matched_pairs['exact'])
                if table_matched_pairs['exact'] else 'empty')
            print(
                'Partial match:', ', '.join(table_matched_pairs['partial'])
                if table_matched_pairs['partial'] else 'empty')
            print('Column matched: (column name, column id, '
                  'question span, start id, end id)')
            print(
                'Exact match:', ', '.join(column_matched_pairs['exact'])
                if column_matched_pairs['exact'] else 'empty')
            print(
                'Partial match:', ', '.join(column_matched_pairs['partial'])
                if column_matched_pairs['partial'] else 'empty')
            print(
                'Value match:', ', '.join(column_matched_pairs['value'])
                if column_matched_pairs['value'] else 'empty', '\n')
        return entry

@@ -0,0 +1,333 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
CLAUSE_KEYWORDS = ('SELECT', 'FROM', 'WHERE', 'GROUP', 'ORDER', 'LIMIT',
                   'INTERSECT', 'UNION', 'EXCEPT')
JOIN_KEYWORDS = ('JOIN', 'ON', 'AS')
WHERE_OPS = ('NOT_IN', 'BETWEEN', '=', '>', '<', '>=', '<=', '!=', 'IN',
             'LIKE', 'IS', 'EXISTS')
UNIT_OPS = ('NONE', '-', '+', '*', '/')
AGG_OPS = ('', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG')
TABLE_TYPE = {
    'sql': 'sql',
    'table_unit': 'table_unit',
}
COND_OPS = ('AND', 'OR')
SQL_OPS = ('INTERSECT', 'UNION', 'EXCEPT')
ORDER_OPS = ('DESC', 'ASC')


def get_select_labels(select, slot, cur_nest):
    for item in select[1]:
        if AGG_OPS[item[0]] != '':
            if slot[item[1][1][1]] == '':
                slot[item[1][1][1]] += (cur_nest + ' ' + AGG_OPS[item[0]])
            else:
                slot[item[1][1][1]] += (' ' + cur_nest + ' '
                                        + AGG_OPS[item[0]])
        else:
            if slot[item[1][1][1]] == '':
                slot[item[1][1][1]] += (cur_nest)
            else:
                slot[item[1][1][1]] += (' ' + cur_nest)
    return slot
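
Review note (not part of the diff): a hand-traced illustration of the slot filling above, assuming the Spider encoding of the select clause, i.e. select = (is_distinct, [(agg_id, val_unit), ...]), val_unit = (unit_op, col_unit, col_unit2) and col_unit = (agg_id, col_id, is_distinct).

    # roughly 'SELECT <col 1>, COUNT(<col 2>)', no nesting
    select = (False, [(0, (0, (0, 1, False), None)),
                      (3, (0, (0, 2, False), None))])
    slot = {0: '', 1: '', 2: ''}
    slot = get_select_labels(select, slot, 'SELECT')
    # slot == {0: '', 1: 'SELECT', 2: 'SELECT COUNT'}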

def get_groupby_labels(groupby, slot, cur_nest):
    for item in groupby:
        if slot[item[1]] == '':
            slot[item[1]] += (cur_nest)
        else:
            slot[item[1]] += (' ' + cur_nest)
    return slot


def get_orderby_labels(orderby, limit, slot, cur_nest):
    if limit is None:
        thelimit = ''
    else:
        thelimit = ' LIMIT'
    for item in orderby[1]:
        if AGG_OPS[item[1][0]] != '':
            agg = ' ' + AGG_OPS[item[1][0]] + ' '
        else:
            agg = ' '
        if slot[item[1][1]] == '':
            slot[item[1][1]] += (
                cur_nest + agg + orderby[0].upper() + thelimit)
        else:
            slot[item[1][1]] += (' ' + cur_nest + agg + orderby[0].upper()
                                 + thelimit)
    return slot


def get_intersect_labels(intersect, slot, cur_nest):
    if isinstance(intersect, dict):
        if cur_nest != '':
            slot = get_labels(intersect, slot, cur_nest)
        else:
            slot = get_labels(intersect, slot, 'INTERSECT')
    else:
        return slot
    return slot


def get_except_labels(texcept, slot, cur_nest):
    if isinstance(texcept, dict):
        if cur_nest != '':
            slot = get_labels(texcept, slot, cur_nest)
        else:
            slot = get_labels(texcept, slot, 'EXCEPT')
    else:
        return slot
    return slot


def get_union_labels(union, slot, cur_nest):
    if isinstance(union, dict):
        if cur_nest != '':
            slot = get_labels(union, slot, cur_nest)
        else:
            slot = get_labels(union, slot, 'UNION')
    else:
        return slot
    return slot


def get_from_labels(tfrom, slot, cur_nest):
    if tfrom['table_units'][0][0] == 'sql':
        slot = get_labels(tfrom['table_units'][0][1], slot, 'OP_SEL')
    else:
        return slot
    return slot


def get_having_labels(having, slot, cur_nest):
    if len(having) == 1:
        item = having[0]
        if item[0] is True:
            neg = ' NOT'
        else:
            neg = ''
        if isinstance(item[3], dict):
            if AGG_OPS[item[2][1][0]] != '':
                agg = ' ' + AGG_OPS[item[2][1][0]]
            else:
                agg = ''
            if slot[item[2][1][1]] == '':
                slot[item[2][1][1]] += (
                    cur_nest + agg + neg + ' ' + WHERE_OPS[item[1]])
            else:
                slot[item[2][1][1]] += (' ' + cur_nest + agg + neg + ' '
                                        + WHERE_OPS[item[1]])
            slot = get_labels(item[3], slot, 'OP_SEL')
        else:
            if AGG_OPS[item[2][1][0]] != '':
                agg = ' ' + AGG_OPS[item[2][1][0]] + ' '
            else:
                agg = ' '
            if slot[item[2][1][1]] == '':
                slot[item[2][1][1]] += (cur_nest + agg + WHERE_OPS[item[1]])
            else:
                slot[item[2][1][1]] += (' ' + cur_nest + agg
                                        + WHERE_OPS[item[1]])
    else:
        for index, item in enumerate(having):
            if item[0] is True:
                neg = ' NOT'
            else:
                neg = ''
            if (index + 1 < len(having) and having[index + 1] == 'or') or (
                    index - 1 >= 0 and having[index - 1] == 'or'):
                if AGG_OPS[item[2][1][0]] != '':
                    agg = ' ' + AGG_OPS[item[2][1][0]]
                else:
                    agg = ''
                if isinstance(item[3], dict):
                    if slot[item[2][1][1]] == '':
                        slot[item[2][1][1]] += (
                            cur_nest + agg + neg + ' ' + WHERE_OPS[item[1]])
                    else:
                        slot[item[2][1][1]] += (' ' + cur_nest + agg + neg
                                                + ' ' + WHERE_OPS[item[1]])
                    slot = get_labels(item[3], slot, 'OP_SEL')
                else:
                    if AGG_OPS[item[2][1][0]] != '':
                        agg = ' ' + AGG_OPS[item[2][1][0]] + ' '
                    else:
                        agg = ' '
                    if slot[item[2][1][1]] == '':
                        slot[item[2][1][1]] += (
                            cur_nest + ' OR' + agg + WHERE_OPS[item[1]])
                    else:
                        slot[item[2][1][1]] += (' ' + cur_nest + ' OR' + agg
                                                + WHERE_OPS[item[1]])
            elif item == 'and' or item == 'or':
                continue
            else:
                if isinstance(item[3], dict):
                    if slot[item[2][1][1]] == '':
                        slot[item[2][1][1]] += (
                            cur_nest + neg + ' ' + WHERE_OPS[item[1]])
                    else:
                        slot[item[2][1][1]] += (' ' + cur_nest + neg + ' '
                                                + WHERE_OPS[item[1]])
                    slot = get_labels(item[3], slot, 'OP_SEL')
                else:
                    if AGG_OPS[item[2][1][0]] != '':
                        agg = ' ' + AGG_OPS[item[2][1][0]] + ' '
                    else:
                        agg = ' '
                    if slot[item[2][1][1]] == '':
                        slot[item[2][1][1]] += (
                            cur_nest + agg + WHERE_OPS[item[1]])
                    else:
                        slot[item[2][1][1]] += (' ' + cur_nest + agg
                                                + WHERE_OPS[item[1]])
    return slot


def get_where_labels(where, slot, cur_nest):
    if len(where) == 1:
        item = where[0]
        if item[0] is True:
            neg = ' NOT'
        else:
            neg = ''
        if isinstance(item[3], dict):
            if slot[item[2][1][1]] == '':
                slot[item[2][1][1]] += (
                    cur_nest + neg + ' ' + WHERE_OPS[item[1]])
            else:
                slot[item[2][1][1]] += (' ' + cur_nest + neg + ' '
                                        + WHERE_OPS[item[1]])
            slot = get_labels(item[3], slot, 'OP_SEL')
        else:
            if slot[item[2][1][1]] == '':
                slot[item[2][1][1]] += (
                    cur_nest + neg + ' ' + WHERE_OPS[item[1]])
            else:
                slot[item[2][1][1]] += (' ' + cur_nest + neg + ' '
                                        + WHERE_OPS[item[1]])
    else:
        for index, item in enumerate(where):
            if item[0] is True:
                neg = ' NOT'
            else:
                neg = ''
            if (index + 1 < len(where) and where[index + 1] == 'or') or (
                    index - 1 >= 0 and where[index - 1] == 'or'):
                if isinstance(item[3], dict):
                    if slot[item[2][1][1]] == '':
                        slot[item[2][1][1]] += (
                            cur_nest + neg + ' ' + WHERE_OPS[item[1]])
                    else:
                        slot[item[2][1][1]] += (' ' + cur_nest + neg + ' '
                                                + WHERE_OPS[item[1]])
                    slot = get_labels(item[3], slot, 'OP_SEL')
                else:
                    if slot[item[2][1][1]] == '':
                        slot[item[2][1][1]] += (
                            cur_nest + ' OR' + neg + ' ' + WHERE_OPS[item[1]])
                    else:
                        slot[item[2][1][1]] += (' ' + cur_nest + ' OR' + neg
                                                + ' ' + WHERE_OPS[item[1]])
            elif item == 'and' or item == 'or':
                continue
            else:
                if isinstance(item[3], dict):
                    if slot[item[2][1][1]] == '':
                        slot[item[2][1][1]] += (
                            cur_nest + neg + ' ' + WHERE_OPS[item[1]])
                    else:
                        slot[item[2][1][1]] += (' ' + cur_nest + neg + ' '
                                                + WHERE_OPS[item[1]])
                    slot = get_labels(item[3], slot, 'OP_SEL')
                else:
                    if slot[item[2][1][1]] == '':
                        slot[item[2][1][1]] += (
                            cur_nest + neg + ' ' + WHERE_OPS[item[1]])
                    else:
                        slot[item[2][1][1]] += (' ' + cur_nest + neg + ' '
                                                + WHERE_OPS[item[1]])
    return slot


def get_labels(sql_struct, slot, cur_nest):
    if len(sql_struct['select']) > 0:
        if cur_nest != '':
            slot = get_select_labels(sql_struct['select'], slot,
                                     cur_nest + ' SELECT')
        else:
            slot = get_select_labels(sql_struct['select'], slot, 'SELECT')
    if sql_struct['from']:
        if cur_nest != '':
            slot = get_from_labels(sql_struct['from'], slot, 'FROM')
        else:
            slot = get_from_labels(sql_struct['from'], slot, 'FROM')
    if len(sql_struct['where']) > 0:
        if cur_nest != '':
            slot = get_where_labels(sql_struct['where'], slot,
                                    cur_nest + ' WHERE')
        else:
            slot = get_where_labels(sql_struct['where'], slot, 'WHERE')
    if len(sql_struct['groupBy']) > 0:
        if cur_nest != '':
            slot = get_groupby_labels(sql_struct['groupBy'], slot,
                                      cur_nest + ' GROUP_BY')
        else:
            slot = get_groupby_labels(sql_struct['groupBy'], slot, 'GROUP_BY')
    if len(sql_struct['having']) > 0:
        if cur_nest != '':
            slot = get_having_labels(sql_struct['having'], slot,
                                     cur_nest + ' HAVING')
        else:
            slot = get_having_labels(sql_struct['having'], slot, 'HAVING')
    if len(sql_struct['orderBy']) > 0:
        if cur_nest != '':
            slot = get_orderby_labels(sql_struct['orderBy'],
                                      sql_struct['limit'], slot,
                                      cur_nest + ' ORDER_BY')
        else:
            slot = get_orderby_labels(sql_struct['orderBy'],
                                      sql_struct['limit'], slot, 'ORDER_BY')
    if sql_struct['intersect']:
        if cur_nest != '':
            slot = get_intersect_labels(sql_struct['intersect'], slot,
                                        cur_nest + ' INTERSECT')
        else:
            slot = get_intersect_labels(sql_struct['intersect'], slot,
                                        'INTERSECT')
    if sql_struct['except']:
        if cur_nest != '':
            slot = get_except_labels(sql_struct['except'], slot,
                                     cur_nest + ' EXCEPT')
        else:
            slot = get_except_labels(sql_struct['except'], slot, 'EXCEPT')
    if sql_struct['union']:
        if cur_nest != '':
            slot = get_union_labels(sql_struct['union'], slot,
                                    cur_nest + ' UNION')
        else:
            slot = get_union_labels(sql_struct['union'], slot, 'UNION')
    return slot


def get_label(sql, column_len):
    thelabel = []
    slot = {}
    for idx in range(column_len):
        slot[idx] = ''
    for value in get_labels(sql, slot, '').values():
        thelabel.append(value)
    return thelabel

@@ -0,0 +1,37 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from text2sql_lgesql.preprocess.parse_raw_json import Schema, get_schemas
from text2sql_lgesql.process_sql import get_sql

from modelscope.preprocessors.star.fields.parse import get_label


def preprocess_dataset(processor, dataset, output_tables, database_id, tables):
    schemas, db_names, thetables = get_schemas(tables)
    intables = output_tables[database_id]
    schema = schemas[database_id]
    table = thetables[database_id]
    sql_label = []
    if len(dataset['history']) == 0 or dataset['last_sql'] == '':
        sql_label = [''] * len(intables['column_names'])
    else:
        schema = Schema(schema, table)
        try:
            sql_label = get_sql(schema, dataset['last_sql'])
            sql_label = get_label(sql_label,
                                  len(table['column_names_original']))
        except Exception:
            sql_label = [''] * len(intables['column_names'])
    theone = {'db_id': database_id}
    theone['query'] = ''
    theone['query_toks_no_value'] = []
    theone['sql'] = {}
    if len(dataset['history']) != 0:
        theone['question'] = dataset['utterance'] + ' [CLS] ' + ' [CLS] '.join(
            dataset['history'][::-1][:4])
        theone['question_toks'] = theone['question'].split()
    else:
        theone['question'] = dataset['utterance']
        theone['question_toks'] = dataset['utterance'].split()
    return [theone], sql_label
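
Review note (not part of the diff): a hand-traced example of the question assembly above for a follow-up turn, with values taken from the test case at the end of this review.

    dataset = {
        'utterance': 'Which of these are hiring?',
        'history': ["I'd like to see Shop names."],
        'last_sql': 'SELECT shop.Name FROM shop'
    }
    # theone['question'] ==
    #     "Which of these are hiring? [CLS] I'd like to see Shop names."
    # theone['question_toks'] == theone['question'].split()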

@@ -0,0 +1,64 @@
# Copyright (c) rhythmcao modified from https://github.com/rhythmcao/text2sql-lgesql.
import argparse
import os
import pickle
import sys
import time

import json
from text2sql_lgesql.asdl.asdl import ASDLGrammar
from text2sql_lgesql.asdl.transition_system import TransitionSystem

from modelscope.preprocessors.star.fields.common_utils import SubPreprocessor

sys.path.append(os.path.dirname(os.path.dirname(__file__)))


def process_example(processor, entry, db, trans, verbose=False):
    # preprocess raw tokens, schema linking and subgraph extraction
    entry = processor.pipeline(entry, db, verbose=verbose)
    # generate target output actions
    entry['ast'] = []
    entry['actions'] = []
    return entry


def process_tables(processor, tables_list, output_path=None, verbose=False):
    tables = {}
    for each in tables_list:
        if verbose:
            print('*************** Processing database %s **************' %
                  (each['db_id']))
        tables[each['db_id']] = processor.preprocess_database(
            each, verbose=verbose)
    print('In total, processed %d databases.' % (len(tables)))
    if output_path is not None:
        pickle.dump(tables, open(output_path, 'wb'))
    return tables


def process_dataset(model_dir,
                    processor,
                    dataset,
                    tables,
                    output_path=None,
                    skip_large=False,
                    verbose=False):
    grammar = ASDLGrammar.from_filepath(
        os.path.join(model_dir, 'sql_asdl_v2.txt'))
    trans = TransitionSystem.get_class_by_lang('sql')(grammar)
    processed_dataset = []
    for idx, entry in enumerate(dataset):
        if skip_large and len(tables[entry['db_id']]['column_names']) > 100:
            continue
        if verbose:
            print('*************** Processing %d-th sample **************' %
                  (idx))
        entry = process_example(
            processor, entry, tables[entry['db_id']], trans, verbose=verbose)
        processed_dataset.append(entry)
    if output_path is not None:
        # serialize preprocessed dataset
        pickle.dump(processed_dataset, open(output_path, 'wb'))
    return processed_dataset

@@ -89,6 +89,7 @@ class NLPTasks(object):
     zero_shot_classification = 'zero-shot-classification'
     backbone = 'backbone'
     text_error_correction = 'text-error-correction'
+    conversational_text_to_sql = 'conversational-text-to-sql'
 
 
 class AudioTasks(object):

@@ -6,5 +6,6 @@ pai-easynlp
 rouge_score<=0.0.4
 seqeval
 spacy>=2.3.5
+text2sql_lgesql
 tokenizers
 transformers>=4.12.0

@@ -0,0 +1,97 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from typing import List

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import StarForTextToSql
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import ConversationalTextToSqlPipeline
from modelscope.preprocessors import ConversationalTextToSqlPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class ConversationalTextToSql(unittest.TestCase):
    model_id = 'damo/nlp_star_conversational-text-to-sql'
    test_case = {
        'database_id':
        'employee_hire_evaluation',
        'local_db_path':
        None,
        'utterance': [
            "I'd like to see Shop names.", 'Which of these are hiring?',
            'Which shop is hiring the highest number of employees? | do you want the name of the shop ? | Yes'
        ]
    }

    def tracking_and_print_results(
            self, pipelines: List[ConversationalTextToSqlPipeline]):
        for my_pipeline in pipelines:
            last_sql, history = '', []
            for item in self.test_case['utterance']:
                case = {
                    'utterance': item,
                    'history': history,
                    'last_sql': last_sql,
                    'database_id': self.test_case['database_id'],
                    'local_db_path': self.test_case['local_db_path']
                }
                results = my_pipeline(case)
                print({'question': item})
                print(results)
                last_sql = results['text']
                history.append(item)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        cache_path = snapshot_download(self.model_id)
        preprocessor = ConversationalTextToSqlPreprocessor(
            model_dir=cache_path,
            database_id=self.test_case['database_id'],
            db_content=True)
        model = StarForTextToSql(
            model_dir=cache_path, config=preprocessor.config)
        pipelines = [
            ConversationalTextToSqlPipeline(
                model=model, preprocessor=preprocessor),
            pipeline(
                task=Tasks.conversational_text_to_sql,
                model=model,
                preprocessor=preprocessor)
        ]
        self.tracking_and_print_results(pipelines)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        preprocessor = ConversationalTextToSqlPreprocessor(
            model_dir=model.model_dir)
        pipelines = [
            ConversationalTextToSqlPipeline(
                model=model, preprocessor=preprocessor),
            pipeline(
                task=Tasks.conversational_text_to_sql,
                model=model,
                preprocessor=preprocessor)
        ]
        self.tracking_and_print_results(pipelines)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name(self):
        pipelines = [
            pipeline(
                task=Tasks.conversational_text_to_sql, model=self.model_id)
        ]
        self.tracking_and_print_results(pipelines)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipelines = [pipeline(task=Tasks.conversational_text_to_sql)]
        self.tracking_and_print_results(pipelines)


if __name__ == '__main__':
    unittest.main()