相比于master上的tableqa,做出了如下修复:
2. 修复了schema linking中的问题:列名不再经过正则清洗,候选表排序时增加了基于history sql的table_score。
3. 同时支持带history sql和不带history sql的两种输入
3. 增加了sqlite执行逻辑,可以返回sql执行结果
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10365114
master
| @@ -3,9 +3,11 @@ | |||
| import os | |||
| from typing import Dict | |||
| import json | |||
| import numpy | |||
| import torch | |||
| import torch.nn.functional as F | |||
| import tqdm | |||
| from transformers import BertTokenizer | |||
| from modelscope.metainfo import Models | |||
| @@ -82,7 +84,6 @@ class TableQuestionAnswering(Model): | |||
| if ntok.startswith('##'): | |||
| ntok = ntok.replace('##', '') | |||
| tok = nlu1[idx:idx + 1].lower() | |||
| if ntok == tok: | |||
| conv_dict[i] = [idx, idx + 1] | |||
| @@ -37,6 +37,7 @@ class OutputKeys(object): | |||
| WORD = 'word' | |||
| KWS_LIST = 'kws_list' | |||
| HISTORY = 'history' | |||
| QUERT_RESULT = 'query_result' | |||
| TIMESTAMPS = 'timestamps' | |||
| SHOT_NUM = 'shot_num' | |||
| SCENE_NUM = 'scene_num' | |||
| @@ -2,6 +2,8 @@ | |||
| import os | |||
| from typing import Any, Dict, Union | |||
| import json | |||
| import torch | |||
| from transformers import BertTokenizer | |||
| from modelscope.metainfo import Pipelines | |||
| @@ -230,14 +232,16 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||
| str_sel_list.append(header_name) | |||
| sql_sel_list.append(header_id) | |||
| else: | |||
| str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '( ' | |||
| + header_name + ' )') | |||
| sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '( ' | |||
| + header_id + ' )') | |||
| str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' | |||
| + header_name + ')') | |||
| sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' | |||
| + header_id + ')') | |||
| str_cond_list, sql_cond_list = [], [] | |||
| for cond in sql['conds']: | |||
| header_name = header_names[cond[0]] | |||
| if header_name == '空列': | |||
| continue | |||
| header_id = '`%s`.`%s`' % (table['table_id'], header_ids[cond[0]]) | |||
| op = self.cond_ops[cond[1]] | |||
| value = cond[2] | |||
| @@ -248,12 +252,17 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||
| cond = ' ' + self.cond_conn_ops[sql['cond_conn_op']] + ' ' | |||
| final_str = 'SELECT %s FROM %s WHERE %s' % (', '.join(str_sel_list), | |||
| table['table_name'], | |||
| cond.join(str_cond_list)) | |||
| final_sql = 'SELECT %s FROM `%s` WHERE %s' % (', '.join(sql_sel_list), | |||
| table['table_id'], | |||
| cond.join(sql_cond_list)) | |||
| if len(str_cond_list) != 0: | |||
| final_str = 'SELECT %s FROM %s WHERE %s' % (', '.join( | |||
| str_sel_list), table['table_name'], cond.join(str_cond_list)) | |||
| final_sql = 'SELECT %s FROM `%s` WHERE %s' % (', '.join( | |||
| sql_sel_list), table['table_id'], cond.join(sql_cond_list)) | |||
| else: | |||
| final_str = 'SELECT %s FROM %s' % (', '.join(str_sel_list), | |||
| table['table_name']) | |||
| final_sql = 'SELECT %s FROM `%s`' % (', '.join(sql_sel_list), | |||
| table['table_id']) | |||
| sql = SQLQuery( | |||
| string=final_str, query=final_sql, sql_result=result['sql']) | |||
| @@ -274,9 +283,39 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||
| history_sql=history_sql, | |||
| result=result, | |||
| table=self.db.tables[result['table_id']]) | |||
| result['sql']['from'] = [result['table_id']] | |||
| sql = self.sql_dict_to_str( | |||
| result=result, table=self.db.tables[result['table_id']]) | |||
| output = {OutputKeys.OUTPUT: sql, OutputKeys.HISTORY: result['sql']} | |||
| # add sqlite | |||
| if self.db.is_use_sqlite: | |||
| try: | |||
| cursor = self.db.connection_obj.cursor().execute(sql.query) | |||
| names = [{ | |||
| 'name': | |||
| description[0], | |||
| 'label': | |||
| self.db.tables[result['table_id']]['headerid2name'].get( | |||
| description[0], description[0]) | |||
| } for description in cursor.description] | |||
| cells = [] | |||
| for res in cursor.fetchall(): | |||
| row = {} | |||
| for name, cell in zip(names, res): | |||
| row[name['name']] = cell | |||
| cells.append(row) | |||
| tabledata = {'headers': names, 'cells': cells} | |||
| except Exception: | |||
| tabledata = {'headers': [], 'cells': []} | |||
| else: | |||
| tabledata = {'headers': [], 'cells': []} | |||
| output = { | |||
| OutputKeys.OUTPUT: sql, | |||
| OutputKeys.HISTORY: result['sql'], | |||
| OutputKeys.QUERT_RESULT: json.dumps(tabledata, ensure_ascii=False), | |||
| } | |||
| return output | |||
| def _collate_fn(self, data): | |||
| @@ -1,4 +1,6 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import sqlite3 | |||
| import json | |||
| import tqdm | |||
| @@ -7,18 +9,38 @@ from modelscope.preprocessors.star3.fields.struct import Trie | |||
| class Database: | |||
| def __init__(self, tokenizer, table_file_path, syn_dict_file_path): | |||
| def __init__(self, | |||
| tokenizer, | |||
| table_file_path, | |||
| syn_dict_file_path, | |||
| is_use_sqlite=False): | |||
| self.tokenizer = tokenizer | |||
| self.is_use_sqlite = is_use_sqlite | |||
| if self.is_use_sqlite: | |||
| self.connection_obj = sqlite3.connect(':memory:') | |||
| self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'} | |||
| self.tables = self.init_tables(table_file_path=table_file_path) | |||
| self.syn_dict = self.init_syn_dict( | |||
| syn_dict_file_path=syn_dict_file_path) | |||
def __del__(self):
    """Close the SQLite connection, if one was opened, on finalization.

    The attributes are read with ``getattr`` because a finalizer can run on
    an instance whose ``__init__`` did not finish (for example when opening
    the connection raised), in which case ``is_use_sqlite`` and/or
    ``connection_obj`` may not exist yet.
    """
    connection = getattr(self, 'connection_obj', None)
    # Only close when sqlite support was requested and a connection exists.
    if getattr(self, 'is_use_sqlite', False) and connection is not None:
        connection.close()
| def init_tables(self, table_file_path): | |||
| tables = {} | |||
| lines = [] | |||
| with open(table_file_path, 'r') as fo: | |||
| for line in fo: | |||
| lines.append(line) | |||
| if type(table_file_path) == str: | |||
| with open(table_file_path, 'r') as fo: | |||
| for line in fo: | |||
| lines.append(line) | |||
| elif type(table_file_path) == list: | |||
| for path in table_file_path: | |||
| with open(path, 'r') as fo: | |||
| for line in fo: | |||
| lines.append(line) | |||
| else: | |||
| raise ValueError() | |||
| for line in tqdm.tqdm(lines, desc='Load Tables'): | |||
| table = json.loads(line.strip()) | |||
| @@ -34,6 +56,9 @@ class Database: | |||
| headers_tokens.append(empty_column) | |||
| table['tablelen'] = table_header_length | |||
| table['header_tok'] = headers_tokens | |||
| table['headerid2name'] = {} | |||
| for hid, hname in zip(table['header_id'], table['header_name']): | |||
| table['headerid2name'][hid] = hname | |||
| table['header_types'].append('null') | |||
| table['header_units'] = [ | |||
| @@ -51,6 +76,26 @@ class Database: | |||
| trie_set[ii].insert(word, word) | |||
| table['value_trie'] = trie_set | |||
| # create sqlite | |||
| if self.is_use_sqlite: | |||
| cursor_obj = self.connection_obj.cursor() | |||
| cursor_obj.execute('DROP TABLE IF EXISTS %s' % | |||
| (table['table_id'])) | |||
| header_string = ', '.join([ | |||
| '%s %s' % | |||
| (name, self.type_dict[htype]) for name, htype in zip( | |||
| table['header_id'], table['header_types']) | |||
| ]) | |||
| create_table_string = 'CREATE TABLE %s (%s);' % ( | |||
| table['table_id'], header_string) | |||
| cursor_obj.execute(create_table_string) | |||
| for row in table['rows']: | |||
| value_string = ', '.join(['"%s"' % (val) for val in row]) | |||
| insert_row_string = 'INSERT INTO %s VALUES(%s)' % ( | |||
| table['table_id'], value_string) | |||
| cursor_obj.execute(insert_row_string) | |||
| tables[table['table_id']] = table | |||
| return tables | |||
| @@ -287,7 +287,13 @@ class SchemaLinker: | |||
| return match_len / (len(nlu_t) + 0.1) | |||
| def get_entity_linking(self, tokenizer, nlu, nlu_t, tables, col_syn_dict): | |||
| def get_entity_linking(self, | |||
| tokenizer, | |||
| nlu, | |||
| nlu_t, | |||
| tables, | |||
| col_syn_dict, | |||
| history_sql=None): | |||
| """ | |||
| get linking between question and schema column | |||
| """ | |||
| @@ -305,8 +311,7 @@ class SchemaLinker: | |||
| typeinfos = [] | |||
| for ii, column in enumerate(table['header_name']): | |||
| column = column.lower() | |||
| column_new = re.sub('(.*?)', '', column) | |||
| column_new = re.sub('(.*?)', '', column_new) | |||
| column_new = column | |||
| cphrase, cscore = self.get_match_phrase( | |||
| nlu.lower(), column_new) | |||
| if cscore > 0.3 and cphrase.strip() != '': | |||
| @@ -330,7 +335,6 @@ class SchemaLinker: | |||
| for cell in ans.keys(): | |||
| vphrase = cell | |||
| vscore = 1.0 | |||
| # print("trie_set find:", cell, ans[cell]) | |||
| phrase_tok = tokenizer.tokenize(vphrase) | |||
| if len(phrase_tok) == 0 or len(vphrase) < 2: | |||
| continue | |||
| @@ -408,16 +412,25 @@ class SchemaLinker: | |||
| match_score = self.get_table_match_score(nlu_t, schema_link) | |||
| search_result = { | |||
| 'table_id': table['table_id'], | |||
| 'question_knowledge': final_question, | |||
| 'header_knowledge': final_header, | |||
| 'schema_link': schema_link, | |||
| 'match_score': match_score | |||
| 'table_id': | |||
| table['table_id'], | |||
| 'question_knowledge': | |||
| final_question, | |||
| 'header_knowledge': | |||
| final_header, | |||
| 'schema_link': | |||
| schema_link, | |||
| 'match_score': | |||
| match_score, | |||
| 'table_score': | |||
| int(table['table_id'] == history_sql['from'][0]) | |||
| if history_sql is not None else 0 | |||
| } | |||
| search_result_list.append(search_result) | |||
| search_result_list = sorted( | |||
| search_result_list, key=lambda x: x['match_score'], | |||
| search_result_list, | |||
| key=lambda x: (x['match_score'], x['table_score']), | |||
| reverse=True)[0:4] | |||
| return search_result_list | |||
| @@ -95,7 +95,7 @@ class TableQuestionAnsweringPreprocessor(Preprocessor): | |||
| # tokenize question | |||
| question = data['question'] | |||
| history_sql = data['history_sql'] | |||
| history_sql = data.get('history_sql', None) | |||
| nlu = question.lower() | |||
| nlu_t = self.tokenizer.tokenize(nlu) | |||
| @@ -105,7 +105,8 @@ class TableQuestionAnsweringPreprocessor(Preprocessor): | |||
| nlu=nlu, | |||
| nlu_t=nlu_t, | |||
| tables=self.db.tables, | |||
| col_syn_dict=self.db.syn_dict) | |||
| col_syn_dict=self.db.syn_dict, | |||
| history_sql=history_sql) | |||
| # collect data | |||
| datas = self.construct_data( | |||
| @@ -2,8 +2,7 @@ from typing import List | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.nlp import (ConversationalTextToSqlPipeline, | |||
| DialogStateTrackingPipeline, | |||
| TableQuestionAnsweringPipeline) | |||
| DialogStateTrackingPipeline) | |||
| def text2sql_tracking_and_print_results( | |||
| @@ -42,17 +41,3 @@ def tracking_and_print_dialog_states( | |||
| print(json.dumps(result)) | |||
| history_states.extend([result[OutputKeys.OUTPUT], {}]) | |||
| def tableqa_tracking_and_print_results( | |||
| test_case, pipelines: List[TableQuestionAnsweringPipeline]): | |||
| for pipeline in pipelines: | |||
| historical_queries = None | |||
| for question in test_case['utterance']: | |||
| output_dict = pipeline({ | |||
| 'question': question, | |||
| 'history_sql': historical_queries | |||
| }) | |||
| print('output_dict', output_dict['output'].string, | |||
| output_dict['output'].query) | |||
| historical_queries = output_dict['history'] | |||
| @@ -1,6 +1,7 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import unittest | |||
| from typing import List | |||
| from transformers import BertTokenizer | |||
| @@ -11,10 +12,60 @@ from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline | |||
| from modelscope.preprocessors import TableQuestionAnsweringPreprocessor | |||
| from modelscope.preprocessors.star3.fields.database import Database | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.nlp.nlp_utils import tableqa_tracking_and_print_results | |||
| from modelscope.utils.test_utils import test_level | |||
def tableqa_tracking_and_print_results_with_history(
        pipelines: 'List[TableQuestionAnsweringPipeline]', test_case=None):
    """Run a multi-turn TableQA dialogue through each pipeline and print it.

    Each utterance is sent together with the ``'history'`` value returned for
    the previous utterance (``None`` on the first turn). The history is reset
    before each pipeline is exercised.

    Args:
        pipelines: Callables taking ``{'question': ..., 'history_sql': ...}``
            and returning a mapping with the keys ``'output'`` (an object
            exposing ``string`` and ``query``), ``'query_result'`` and
            ``'history'``.
        test_case: Optional mapping whose ``'utterance'`` entry lists the
            questions to ask, in order. When omitted, the built-in dialogue
            below is used.
    """
    # The annotation above is quoted so that defining this helper does not
    # require the pipeline class to be resolvable at definition time.
    if test_case is None:
        test_case = {
            'utterance': [
                '有哪些风险类型?',
                '风险类型有多少种?',
                '珠江流域的小(2)型水库的库容总量是多少?',
                '那平均值是多少?',
                '那水库的名称呢?',
                '换成中型的呢?',
                '枣庄营业厅的电话',
                '那地址呢?',
                '枣庄营业厅的电话和地址',
            ]
        }
    for p in pipelines:
        # Every pipeline starts its dialogue without any previous SQL.
        historical_queries = None
        for question in test_case['utterance']:
            output_dict = p({
                'question': question,
                'history_sql': historical_queries
            })
            print('question', question)
            print('sql text:', output_dict['output'].string)
            print('sql query:', output_dict['output'].query)
            print('query result:', output_dict['query_result'])
            print()
            # Feed this turn's SQL back in as context for the next turn.
            historical_queries = output_dict['history']
def tableqa_tracking_and_print_results_without_history(
        pipelines: 'List[TableQuestionAnsweringPipeline]', test_case=None):
    """Run independent TableQA questions through each pipeline and print them.

    Each utterance is sent on its own as ``{'question': ...}``; no
    ``'history_sql'`` key is supplied.

    Args:
        pipelines: Callables taking ``{'question': ...}`` and returning a
            mapping with the keys ``'output'`` (an object exposing ``string``
            and ``query``) and ``'query_result'``.
        test_case: Optional mapping whose ``'utterance'`` entry lists the
            questions to ask, in order. When omitted, the built-in questions
            below are used.
    """
    # The annotation above is quoted so that defining this helper does not
    # require the pipeline class to be resolvable at definition time.
    if test_case is None:
        test_case = {
            'utterance': [
                '有哪些风险类型?',
                '风险类型有多少种?',
                '珠江流域的小(2)型水库的库容总量是多少?',
                '枣庄营业厅的电话',
                '枣庄营业厅的电话和地址',
            ]
        }
    for p in pipelines:
        for question in test_case['utterance']:
            output_dict = p({'question': question})
            print('question', question)
            print('sql text:', output_dict['output'].string)
            print('sql query:', output_dict['output'].query)
            print('query result:', output_dict['query_result'])
            print()
| class TableQuestionAnswering(unittest.TestCase): | |||
| def setUp(self) -> None: | |||
| @@ -22,20 +73,18 @@ class TableQuestionAnswering(unittest.TestCase): | |||
| self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn' | |||
| model_id = 'damo/nlp_convai_text2sql_pretrain_cn' | |||
| test_case = { | |||
| 'utterance': | |||
| ['长江流域的小(2)型水库的库容总量是多少?', '那平均值是多少?', '那水库的名称呢?', '换成中型的呢?'] | |||
| } | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_by_direct_model_download(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| preprocessor = TableQuestionAnsweringPreprocessor(model_dir=cache_path) | |||
| pipelines = [ | |||
| TableQuestionAnsweringPipeline( | |||
| model=cache_path, preprocessor=preprocessor) | |||
| pipeline( | |||
| Tasks.table_question_answering, | |||
| model=cache_path, | |||
| preprocessor=preprocessor) | |||
| ] | |||
| tableqa_tracking_and_print_results(self.test_case, pipelines) | |||
| tableqa_tracking_and_print_results_with_history(pipelines) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| @@ -43,15 +92,17 @@ class TableQuestionAnswering(unittest.TestCase): | |||
| preprocessor = TableQuestionAnsweringPreprocessor( | |||
| model_dir=model.model_dir) | |||
| pipelines = [ | |||
| TableQuestionAnsweringPipeline( | |||
| model=model, preprocessor=preprocessor) | |||
| pipeline( | |||
| Tasks.table_question_answering, | |||
| model=model, | |||
| preprocessor=preprocessor) | |||
| ] | |||
| tableqa_tracking_and_print_results(self.test_case, pipelines) | |||
| tableqa_tracking_and_print_results_with_history(pipelines) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_task(self): | |||
| pipelines = [pipeline(Tasks.table_question_answering, self.model_id)] | |||
| tableqa_tracking_and_print_results(self.test_case, pipelines) | |||
| tableqa_tracking_and_print_results_with_history(pipelines) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub_with_other_classes(self): | |||
| @@ -60,15 +111,24 @@ class TableQuestionAnswering(unittest.TestCase): | |||
| os.path.join(model.model_dir, ModelFile.VOCAB_FILE)) | |||
| db = Database( | |||
| tokenizer=self.tokenizer, | |||
| table_file_path=os.path.join(model.model_dir, 'table.json'), | |||
| syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt')) | |||
| table_file_path=[ | |||
| os.path.join(model.model_dir, 'databases', fname) | |||
| for fname in os.listdir( | |||
| os.path.join(model.model_dir, 'databases')) | |||
| ], | |||
| syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'), | |||
| is_use_sqlite=True) | |||
| preprocessor = TableQuestionAnsweringPreprocessor( | |||
| model_dir=model.model_dir, db=db) | |||
| pipelines = [ | |||
| TableQuestionAnsweringPipeline( | |||
| model=model, preprocessor=preprocessor, db=db) | |||
| pipeline( | |||
| Tasks.table_question_answering, | |||
| model=model, | |||
| preprocessor=preprocessor, | |||
| db=db) | |||
| ] | |||
| tableqa_tracking_and_print_results(self.test_case, pipelines) | |||
| tableqa_tracking_and_print_results_without_history(pipelines) | |||
| tableqa_tracking_and_print_results_with_history(pipelines) | |||
| if __name__ == '__main__': | |||