diff --git a/modelscope/models/nlp/table_question_answering.py b/modelscope/models/nlp/table_question_answering.py
index 3c91a518..c6a03ef3 100644
--- a/modelscope/models/nlp/table_question_answering.py
+++ b/modelscope/models/nlp/table_question_answering.py
@@ -3,9 +3,11 @@
 import os
 from typing import Dict
 
+import json
 import numpy
 import torch
 import torch.nn.functional as F
+import tqdm
 from transformers import BertTokenizer
 
 from modelscope.metainfo import Models
@@ -82,7 +84,6 @@ class TableQuestionAnswering(Model):
                 if ntok.startswith('##'):
                     ntok = ntok.replace('##', '')
-
                 tok = nlu1[idx:idx + 1].lower()
                 if ntok == tok:
                     conv_dict[i] = [idx, idx + 1]
diff --git a/modelscope/outputs.py b/modelscope/outputs.py
index dd59d6fb..0f353d3d 100644
--- a/modelscope/outputs.py
+++ b/modelscope/outputs.py
@@ -37,6 +37,7 @@ class OutputKeys(object):
     WORD = 'word'
     KWS_LIST = 'kws_list'
     HISTORY = 'history'
+    QUERY_RESULT = 'query_result'
     TIMESTAMPS = 'timestamps'
     SHOT_NUM = 'shot_num'
     SCENE_NUM = 'scene_num'
diff --git a/modelscope/pipelines/nlp/table_question_answering_pipeline.py b/modelscope/pipelines/nlp/table_question_answering_pipeline.py
index 96bfbc34..e1b2b07b 100644
--- a/modelscope/pipelines/nlp/table_question_answering_pipeline.py
+++ b/modelscope/pipelines/nlp/table_question_answering_pipeline.py
@@ -2,6 +2,8 @@
 import os
 from typing import Any, Dict, Union
 
+import json
+import torch
 from transformers import BertTokenizer
 
 from modelscope.metainfo import Pipelines
@@ -230,14 +232,16 @@ class TableQuestionAnsweringPipeline(Pipeline):
                 str_sel_list.append(header_name)
                 sql_sel_list.append(header_id)
             else:
-                str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '( '
-                                    + header_name + ' )')
-                sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '( '
-                                    + header_id + ' )')
+                str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '('
+                                    + header_name + ')')
+                sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '('
+                                    + header_id + ')')
 
         str_cond_list, sql_cond_list = [], []
         for cond in sql['conds']:
             header_name = header_names[cond[0]]
+            if header_name == '空列':
+                continue
             header_id = '`%s`.`%s`' % (table['table_id'], header_ids[cond[0]])
             op = self.cond_ops[cond[1]]
             value = cond[2]
@@ -248,12 +252,17 @@ class TableQuestionAnsweringPipeline(Pipeline):
 
         cond = ' ' + self.cond_conn_ops[sql['cond_conn_op']] + ' '
-        final_str = 'SELECT %s FROM %s WHERE %s' % (', '.join(str_sel_list),
-                                                    table['table_name'],
-                                                    cond.join(str_cond_list))
-        final_sql = 'SELECT %s FROM `%s` WHERE %s' % (', '.join(sql_sel_list),
-                                                      table['table_id'],
-                                                      cond.join(sql_cond_list))
+        if len(str_cond_list) != 0:
+            final_str = 'SELECT %s FROM %s WHERE %s' % (', '.join(
+                str_sel_list), table['table_name'], cond.join(str_cond_list))
+            final_sql = 'SELECT %s FROM `%s` WHERE %s' % (', '.join(
+                sql_sel_list), table['table_id'], cond.join(sql_cond_list))
+        else:
+            final_str = 'SELECT %s FROM %s' % (', '.join(str_sel_list),
+                                               table['table_name'])
+            final_sql = 'SELECT %s FROM `%s`' % (', '.join(sql_sel_list),
+                                                 table['table_id'])
+
         sql = SQLQuery(
             string=final_str, query=final_sql, sql_result=result['sql'])
@@ -274,9 +283,39 @@ class TableQuestionAnsweringPipeline(Pipeline):
             history_sql=history_sql,
             result=result,
             table=self.db.tables[result['table_id']])
+        result['sql']['from'] = [result['table_id']]
         sql = self.sql_dict_to_str(
             result=result, table=self.db.tables[result['table_id']])
-        output = {OutputKeys.OUTPUT: sql, OutputKeys.HISTORY: result['sql']}
+
+        # add sqlite
+        if self.db.is_use_sqlite:
+            try:
+                cursor = self.db.connection_obj.cursor().execute(sql.query)
+                names = [{
+                    'name':
+                    description[0],
+                    'label':
+                    self.db.tables[result['table_id']]['headerid2name'].get(
+                        description[0], description[0])
+                } for description in cursor.description]
+                cells = []
+                for res in cursor.fetchall():
+                    row = {}
+                    for name, cell in zip(names, res):
+                        row[name['name']] = cell
+                    cells.append(row)
+                tabledata = {'headers': names, 'cells': cells}
+            except Exception:
+                tabledata = {'headers': [], 'cells': []}
+        else:
+            tabledata = {'headers': [], 'cells': []}
+
+        output = {
+            OutputKeys.OUTPUT: sql,
+            OutputKeys.HISTORY: result['sql'],
+            OutputKeys.QUERY_RESULT:
+            json.dumps(tabledata, ensure_ascii=False),
+        }
+
         return output
 
     def _collate_fn(self, data):
diff --git a/modelscope/preprocessors/star3/fields/database.py b/modelscope/preprocessors/star3/fields/database.py
index a99800cf..3d3a1f8d 100644
--- a/modelscope/preprocessors/star3/fields/database.py
+++ b/modelscope/preprocessors/star3/fields/database.py
@@ -1,4 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import sqlite3
+
 import json
 
 import tqdm
@@ -7,18 +9,38 @@ from modelscope.preprocessors.star3.fields.struct import Trie
 
 class Database:
 
-    def __init__(self, tokenizer, table_file_path, syn_dict_file_path):
+    def __init__(self,
+                 tokenizer,
+                 table_file_path,
+                 syn_dict_file_path,
+                 is_use_sqlite=False):
         self.tokenizer = tokenizer
+        self.is_use_sqlite = is_use_sqlite
+        if self.is_use_sqlite:
+            self.connection_obj = sqlite3.connect(':memory:')
+        self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'}
         self.tables = self.init_tables(table_file_path=table_file_path)
         self.syn_dict = self.init_syn_dict(
             syn_dict_file_path=syn_dict_file_path)
 
+    def __del__(self):
+        if self.is_use_sqlite:
+            self.connection_obj.close()
+
     def init_tables(self, table_file_path):
         tables = {}
         lines = []
-        with open(table_file_path, 'r') as fo:
-            for line in fo:
-                lines.append(line)
+        if isinstance(table_file_path, str):
+            with open(table_file_path, 'r') as fo:
+                for line in fo:
+                    lines.append(line)
+        elif isinstance(table_file_path, list):
+            for path in table_file_path:
+                with open(path, 'r') as fo:
+                    for line in fo:
+                        lines.append(line)
+        else:
+            raise ValueError('table_file_path must be a str or a list of str')
 
         for line in tqdm.tqdm(lines, desc='Load Tables'):
             table = json.loads(line.strip())
@@ -34,6 +56,9 @@ class Database:
             headers_tokens.append(empty_column)
             table['tablelen'] = table_header_length
             table['header_tok'] = headers_tokens
+            table['headerid2name'] = {}
+            for hid, hname in zip(table['header_id'], table['header_name']):
+                table['headerid2name'][hid] = hname
             table['header_types'].append('null')
 
             table['header_units'] = [
@@ -51,6 +76,26 @@ class Database:
                         trie_set[ii].insert(word, word)
             table['value_trie'] = trie_set
 
+            # create sqlite
+            if self.is_use_sqlite:
+                cursor_obj = self.connection_obj.cursor()
+                cursor_obj.execute('DROP TABLE IF EXISTS %s' %
+                                   (table['table_id']))
+                header_string = ', '.join([
+                    '%s %s' %
+                    (name, self.type_dict[htype]) for name, htype in zip(
+                        table['header_id'], table['header_types'])
+                ])
+                create_table_string = 'CREATE TABLE %s (%s);' % (
+                    table['table_id'], header_string)
+                cursor_obj.execute(create_table_string)
+                for row in table['rows']:
+                    value_string = ', '.join(['"%s"' % (val) for val in row])
+                    insert_row_string = 'INSERT INTO %s VALUES(%s)' % (
+                        table['table_id'], value_string)
+                    cursor_obj.execute(insert_row_string)
+
             tables[table['table_id']] = table
 
         return tables
diff --git a/modelscope/preprocessors/star3/fields/schema_link.py b/modelscope/preprocessors/star3/fields/schema_link.py
index 40613f78..7f483a1f 100644
--- a/modelscope/preprocessors/star3/fields/schema_link.py
+++ b/modelscope/preprocessors/star3/fields/schema_link.py
@@ -287,7 +287,13 @@ class SchemaLinker:
 
         return match_len / (len(nlu_t) + 0.1)
 
-    def get_entity_linking(self, tokenizer, nlu, nlu_t, tables, col_syn_dict):
+    def get_entity_linking(self,
+                           tokenizer,
+                           nlu,
+                           nlu_t,
+                           tables,
+                           col_syn_dict,
+                           history_sql=None):
         """
         get linking between question and schema column
         """
@@ -305,8 +311,7 @@ class SchemaLinker:
             typeinfos = []
             for ii, column in enumerate(table['header_name']):
                 column = column.lower()
-                column_new = re.sub('（.*?）', '', column)
-                column_new = re.sub('\\(.*?\\)', '', column_new)
+                column_new = column
                 cphrase, cscore = self.get_match_phrase(
                     nlu.lower(), column_new)
                 if cscore > 0.3 and cphrase.strip() != '':
@@ -330,7 +335,6 @@ class SchemaLinker:
             for cell in ans.keys():
                 vphrase = cell
                 vscore = 1.0
-                # print("trie_set find:", cell, ans[cell])
                 phrase_tok = tokenizer.tokenize(vphrase)
                 if len(phrase_tok) == 0 or len(vphrase) < 2:
                     continue
@@ -408,16 +412,25 @@ class SchemaLinker:
             match_score = self.get_table_match_score(nlu_t, schema_link)
 
             search_result = {
-                'table_id': table['table_id'],
-                'question_knowledge': final_question,
-                'header_knowledge': final_header,
-                'schema_link': schema_link,
-                'match_score': match_score
+                'table_id':
+                table['table_id'],
+                'question_knowledge':
+                final_question,
+                'header_knowledge':
+                final_header,
+                'schema_link':
+                schema_link,
+                'match_score':
+                match_score,
+                'table_score':
+                int(table['table_id'] == history_sql['from'][0])
+                if history_sql is not None else 0
             }
             search_result_list.append(search_result)
 
         search_result_list = sorted(
-            search_result_list, key=lambda x: x['match_score'],
+            search_result_list,
+            key=lambda x: (x['match_score'], x['table_score']),
             reverse=True)[0:4]
 
         return search_result_list
diff --git a/modelscope/preprocessors/star3/table_question_answering_preprocessor.py b/modelscope/preprocessors/star3/table_question_answering_preprocessor.py
index 163759a1..f98aa6d0 100644
--- a/modelscope/preprocessors/star3/table_question_answering_preprocessor.py
+++ b/modelscope/preprocessors/star3/table_question_answering_preprocessor.py
@@ -95,7 +95,7 @@ class TableQuestionAnsweringPreprocessor(Preprocessor):
 
         # tokenize question
         question = data['question']
-        history_sql = data['history_sql']
+        history_sql = data.get('history_sql', None)
         nlu = question.lower()
         nlu_t = self.tokenizer.tokenize(nlu)
 
@@ -105,7 +105,8 @@ class TableQuestionAnsweringPreprocessor(Preprocessor):
             nlu=nlu,
             nlu_t=nlu_t,
             tables=self.db.tables,
-            col_syn_dict=self.db.syn_dict)
+            col_syn_dict=self.db.syn_dict,
+            history_sql=history_sql)
 
         # collect data
         datas = self.construct_data(
diff --git a/modelscope/utils/nlp/nlp_utils.py b/modelscope/utils/nlp/nlp_utils.py
index eba12103..35b374f2 100644
--- a/modelscope/utils/nlp/nlp_utils.py
+++ b/modelscope/utils/nlp/nlp_utils.py
@@ -2,8 +2,7 @@ from typing import List
 
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.nlp import (ConversationalTextToSqlPipeline,
-                                      DialogStateTrackingPipeline,
-                                      TableQuestionAnsweringPipeline)
+                                      DialogStateTrackingPipeline)
 
 
 def text2sql_tracking_and_print_results(
@@ -42,17 +41,3 @@ def tracking_and_print_dialog_states(
         print(json.dumps(result))
 
         history_states.extend([result[OutputKeys.OUTPUT], {}])
-
-
-def tableqa_tracking_and_print_results(
-        test_case, pipelines: List[TableQuestionAnsweringPipeline]):
-    for pipeline in pipelines:
-        historical_queries = None
-        for question in test_case['utterance']:
-            output_dict = pipeline({
-                'question': question,
-                'history_sql': historical_queries
-            })
-            print('output_dict', output_dict['output'].string,
-                  output_dict['output'].query)
-            historical_queries = output_dict['history']
diff --git a/tests/pipelines/test_table_question_answering.py b/tests/pipelines/test_table_question_answering.py
index 7ea28725..68e0564f 100644
--- a/tests/pipelines/test_table_question_answering.py
+++ b/tests/pipelines/test_table_question_answering.py
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
 import unittest
+from typing import List
 
 from transformers import BertTokenizer
 
@@ -11,10 +12,60 @@ from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline
 from modelscope.preprocessors import TableQuestionAnsweringPreprocessor
 from modelscope.preprocessors.star3.fields.database import Database
 from modelscope.utils.constant import ModelFile, Tasks
-from modelscope.utils.nlp.nlp_utils import tableqa_tracking_and_print_results
 from modelscope.utils.test_utils import test_level
 
 
+def tableqa_tracking_and_print_results_with_history(
+        pipelines: List[TableQuestionAnsweringPipeline]):
+    test_case = {
+        'utterance': [
+            '有哪些风险类型?',
+            '风险类型有多少种?',
+            '珠江流域的小(2)型水库的库容总量是多少?',
+            '那平均值是多少?',
+            '那水库的名称呢?',
+            '换成中型的呢?',
+            '枣庄营业厅的电话',
+            '那地址呢?',
+            '枣庄营业厅的电话和地址',
+        ]
+    }
+    for p in pipelines:
+        historical_queries = None
+        for question in test_case['utterance']:
+            output_dict = p({
+                'question': question,
+                'history_sql': historical_queries
+            })
+            print('question', question)
+            print('sql text:', output_dict['output'].string)
+            print('sql query:', output_dict['output'].query)
+            print('query result:', output_dict['query_result'])
+            print()
+            historical_queries = output_dict['history']
+
+
+def tableqa_tracking_and_print_results_without_history(
+        pipelines: List[TableQuestionAnsweringPipeline]):
+    test_case = {
+        'utterance': [
+            '有哪些风险类型?',
+            '风险类型有多少种?',
+            '珠江流域的小(2)型水库的库容总量是多少?',
+            '枣庄营业厅的电话',
+            '枣庄营业厅的电话和地址',
+        ]
+    }
+    for p in pipelines:
+        for question in test_case['utterance']:
+            output_dict = p({'question': question})
+            print('question', question)
+            print('sql text:', output_dict['output'].string)
+            print('sql query:', output_dict['output'].query)
+            print('query result:', output_dict['query_result'])
+            print()
+
+
 class TableQuestionAnswering(unittest.TestCase):
 
     def setUp(self) -> None:
@@ -22,20 +73,18 @@ class TableQuestionAnswering(unittest.TestCase):
        self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn'
 
     model_id = 'damo/nlp_convai_text2sql_pretrain_cn'
-    test_case = {
-        'utterance':
-        ['长江流域的小(2)型水库的库容总量是多少?', '那平均值是多少?', '那水库的名称呢?', '换成中型的呢?']
-    }
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_by_direct_model_download(self):
         cache_path = snapshot_download(self.model_id)
         preprocessor = TableQuestionAnsweringPreprocessor(model_dir=cache_path)
         pipelines = [
-            TableQuestionAnsweringPipeline(
-                model=cache_path, preprocessor=preprocessor)
+            pipeline(
+                Tasks.table_question_answering,
+                model=cache_path,
+                preprocessor=preprocessor)
         ]
-        tableqa_tracking_and_print_results(self.test_case, pipelines)
+        tableqa_tracking_and_print_results_with_history(pipelines)
 
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_model_from_modelhub(self):
@@ -43,15 +92,17 @@ def test_run_with_model_from_modelhub(self):
         preprocessor = TableQuestionAnsweringPreprocessor(
             model_dir=model.model_dir)
         pipelines = [
-            TableQuestionAnsweringPipeline(
-                model=model, preprocessor=preprocessor)
+            pipeline(
+                Tasks.table_question_answering,
+                model=model,
+                preprocessor=preprocessor)
         ]
-        tableqa_tracking_and_print_results(self.test_case, pipelines)
+        tableqa_tracking_and_print_results_with_history(pipelines)
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_model_from_task(self):
         pipelines = [pipeline(Tasks.table_question_answering, self.model_id)]
-        tableqa_tracking_and_print_results(self.test_case, pipelines)
+        tableqa_tracking_and_print_results_with_history(pipelines)
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_model_from_modelhub_with_other_classes(self):
@@ -60,15 +111,23 @@ def test_run_with_model_from_modelhub_with_other_classes(self):
             os.path.join(model.model_dir, ModelFile.VOCAB_FILE))
         db = Database(
             tokenizer=self.tokenizer,
-            table_file_path=os.path.join(model.model_dir, 'table.json'),
-            syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'))
+            table_file_path=[
+                os.path.join(model.model_dir, 'databases', fname)
+                for fname in os.listdir(
+                    os.path.join(model.model_dir, 'databases'))
+            ],
+            syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'),
+            is_use_sqlite=True)
         preprocessor = TableQuestionAnsweringPreprocessor(
             model_dir=model.model_dir, db=db)
         pipelines = [
-            TableQuestionAnsweringPipeline(
-                model=model, preprocessor=preprocessor, db=db)
+            pipeline(
+                Tasks.table_question_answering,
+                model=model,
+                preprocessor=preprocessor,
+                db=db)
         ]
-        tableqa_tracking_and_print_results(self.test_case, pipelines)
+        tableqa_tracking_and_print_results_without_history(pipelines)
+        tableqa_tracking_and_print_results_with_history(pipelines)
 
 
 if __name__ == '__main__':
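For reviewers who want to exercise the new sqlite-backed execution path end to end, the tests above condense into a short script. This is a minimal sketch derived from the test code in this diff; the model id and the `databases/` and `synonym.txt` assets are assumptions taken from the tests, not guarantees of this patch.

```python
# Minimal sketch of the sqlite-backed query path, condensed from the tests
# in this diff. Assumption: the model below is published on the hub and
# ships a `databases/` directory and a `synonym.txt` alongside its weights.
import json
import os

from transformers import BertTokenizer

from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.preprocessors import TableQuestionAnsweringPreprocessor
from modelscope.preprocessors.star3.fields.database import Database
from modelscope.utils.constant import ModelFile, Tasks

model = Model.from_pretrained('damo/nlp_convai_text2sql_pretrain_cn')
tokenizer = BertTokenizer(os.path.join(model.model_dir, ModelFile.VOCAB_FILE))

# is_use_sqlite=True materializes every table in an in-memory sqlite
# database so the generated SQL can actually be executed.
db = Database(
    tokenizer=tokenizer,
    table_file_path=[
        os.path.join(model.model_dir, 'databases', fname)
        for fname in os.listdir(os.path.join(model.model_dir, 'databases'))
    ],
    syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'),
    is_use_sqlite=True)
preprocessor = TableQuestionAnsweringPreprocessor(
    model_dir=model.model_dir, db=db)
p = pipeline(
    Tasks.table_question_answering,
    model=model,
    preprocessor=preprocessor,
    db=db)

result = p({'question': '有哪些风险类型?'})
print(result['output'].query)  # generated SQL against the sqlite schema
print(json.loads(result['query_result']))  # headers and rows from sqlite
```

Note that the pipeline serializes the fetched rows with `json.dumps(..., ensure_ascii=False)`, so `query_result` is a JSON string (with empty `headers`/`cells` when sqlite is disabled or the query fails) rather than a live cursor, which keeps the pipeline output plainly serializable.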