Compared with the tableqa implementation on master, this change makes the following fixes:
1. Fixed issues in schema linking: column names are no longer stripped of their parenthesized parts, and the table referenced by the history SQL is preferred when ranking candidate tables.
2. Supported both input modes, with and without a history SQL; `history_sql` is now an optional input (see the usage sketch after the review link).
3. Added SQLite execution logic: tables can be loaded into an in-memory SQLite database, the generated SQL is executed, and the result is returned under `query_result`.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10365114
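
A minimal usage sketch of the two input modes, distilled from the test changes in this diff. The model id and output keys are taken from the diff itself; the snippet assumes the model can be downloaded from the model hub, and `query_result` is only populated when the pipeline's Database is built with `is_use_sqlite=True`:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

table_qa = pipeline(Tasks.table_question_answering,
                    'damo/nlp_convai_text2sql_pretrain_cn')

# multi-turn: feed the previous turn's SQL back in as history_sql
historical_queries = None
for question in ['有哪些风险类型?', '风险类型有多少种?']:
    output_dict = table_qa({
        'question': question,
        'history_sql': historical_queries
    })
    print(output_dict['output'].string)  # human-readable SQL
    print(output_dict['output'].query)   # executable SQL
    print(output_dict['query_result'])   # JSON string; empty without sqlite
    historical_queries = output_dict['history']

# single-turn: history_sql may now be omitted entirely
output_dict = table_qa({'question': '枣庄营业厅的电话'})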
Diff against master:
@@ -3,9 +3,11 @@
 import os
 from typing import Dict

+import json
 import numpy
 import torch
 import torch.nn.functional as F
+import tqdm
 from transformers import BertTokenizer

 from modelscope.metainfo import Models
@@ -82,7 +84,6 @@ class TableQuestionAnswering(Model):
             if ntok.startswith('##'):
                 ntok = ntok.replace('##', '')
-
             tok = nlu1[idx:idx + 1].lower()
             if ntok == tok:
                 conv_dict[i] = [idx, idx + 1]
@@ -37,6 +37,7 @@ class OutputKeys(object):
     WORD = 'word'
     KWS_LIST = 'kws_list'
     HISTORY = 'history'
+    QUERT_RESULT = 'query_result'
     TIMESTAMPS = 'timestamps'
     SHOT_NUM = 'shot_num'
     SCENE_NUM = 'scene_num'
@@ -2,6 +2,8 @@
 import os
 from typing import Any, Dict, Union

+import json
+import torch
 from transformers import BertTokenizer

 from modelscope.metainfo import Pipelines
@@ -230,14 +232,16 @@ class TableQuestionAnsweringPipeline(Pipeline):
                 str_sel_list.append(header_name)
                 sql_sel_list.append(header_id)
             else:
-                str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '( '
-                                    + header_name + ' )')
-                sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '( '
-                                    + header_id + ' )')
+                str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '('
+                                    + header_name + ')')
+                sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '('
+                                    + header_id + ')')

         str_cond_list, sql_cond_list = [], []
         for cond in sql['conds']:
             header_name = header_names[cond[0]]
+            if header_name == '空列':
+                continue
             header_id = '`%s`.`%s`' % (table['table_id'], header_ids[cond[0]])
             op = self.cond_ops[cond[1]]
             value = cond[2]
@@ -248,12 +252,17 @@ class TableQuestionAnsweringPipeline(Pipeline):
         cond = ' ' + self.cond_conn_ops[sql['cond_conn_op']] + ' '
-        final_str = 'SELECT %s FROM %s WHERE %s' % (', '.join(str_sel_list),
-                                                    table['table_name'],
-                                                    cond.join(str_cond_list))
-        final_sql = 'SELECT %s FROM `%s` WHERE %s' % (', '.join(sql_sel_list),
-                                                      table['table_id'],
-                                                      cond.join(sql_cond_list))
+        if len(str_cond_list) != 0:
+            final_str = 'SELECT %s FROM %s WHERE %s' % (', '.join(
+                str_sel_list), table['table_name'], cond.join(str_cond_list))
+            final_sql = 'SELECT %s FROM `%s` WHERE %s' % (', '.join(
+                sql_sel_list), table['table_id'], cond.join(sql_cond_list))
+        else:
+            final_str = 'SELECT %s FROM %s' % (', '.join(str_sel_list),
+                                               table['table_name'])
+            final_sql = 'SELECT %s FROM `%s`' % (', '.join(sql_sel_list),
+                                                 table['table_id'])

         sql = SQLQuery(
             string=final_str, query=final_sql, sql_result=result['sql'])
@@ -274,9 +283,39 @@ class TableQuestionAnsweringPipeline(Pipeline):
                 history_sql=history_sql,
                 result=result,
                 table=self.db.tables[result['table_id']])
+            result['sql']['from'] = [result['table_id']]
         sql = self.sql_dict_to_str(
             result=result, table=self.db.tables[result['table_id']])
-        output = {OutputKeys.OUTPUT: sql, OutputKeys.HISTORY: result['sql']}
+
+        # add sqlite
+        if self.db.is_use_sqlite:
+            try:
+                cursor = self.db.connection_obj.cursor().execute(sql.query)
+                names = [{
+                    'name':
+                    description[0],
+                    'label':
+                    self.db.tables[result['table_id']]['headerid2name'].get(
+                        description[0], description[0])
+                } for description in cursor.description]
+                cells = []
+                for res in cursor.fetchall():
+                    row = {}
+                    for name, cell in zip(names, res):
+                        row[name['name']] = cell
+                    cells.append(row)
+                tabledata = {'headers': names, 'cells': cells}
+            except Exception:
+                tabledata = {'headers': [], 'cells': []}
+        else:
+            tabledata = {'headers': [], 'cells': []}
+        output = {
+            OutputKeys.OUTPUT: sql,
+            OutputKeys.HISTORY: result['sql'],
+            OutputKeys.QUERT_RESULT: json.dumps(tabledata, ensure_ascii=False),
+        }

         return output

     def _collate_fn(self, data):
@@ -1,4 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import sqlite3
+
 import json

 import tqdm
@@ -7,18 +9,38 @@ from modelscope.preprocessors.star3.fields.struct import Trie
 class Database:

-    def __init__(self, tokenizer, table_file_path, syn_dict_file_path):
+    def __init__(self,
+                 tokenizer,
+                 table_file_path,
+                 syn_dict_file_path,
+                 is_use_sqlite=False):
         self.tokenizer = tokenizer
+        self.is_use_sqlite = is_use_sqlite
+        if self.is_use_sqlite:
+            self.connection_obj = sqlite3.connect(':memory:')
+        self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'}
         self.tables = self.init_tables(table_file_path=table_file_path)
         self.syn_dict = self.init_syn_dict(
             syn_dict_file_path=syn_dict_file_path)

+    def __del__(self):
+        if self.is_use_sqlite:
+            self.connection_obj.close()
+
     def init_tables(self, table_file_path):
         tables = {}
         lines = []
-        with open(table_file_path, 'r') as fo:
-            for line in fo:
-                lines.append(line)
+        if type(table_file_path) == str:
+            with open(table_file_path, 'r') as fo:
+                for line in fo:
+                    lines.append(line)
+        elif type(table_file_path) == list:
+            for path in table_file_path:
+                with open(path, 'r') as fo:
+                    for line in fo:
+                        lines.append(line)
+        else:
+            raise ValueError()

         for line in tqdm.tqdm(lines, desc='Load Tables'):
             table = json.loads(line.strip())
@@ -34,6 +56,9 @@ class Database:
             headers_tokens.append(empty_column)
             table['tablelen'] = table_header_length
             table['header_tok'] = headers_tokens
+            table['headerid2name'] = {}
+            for hid, hname in zip(table['header_id'], table['header_name']):
+                table['headerid2name'][hid] = hname

             table['header_types'].append('null')
             table['header_units'] = [
@@ -51,6 +76,26 @@ class Database:
                     trie_set[ii].insert(word, word)
             table['value_trie'] = trie_set

+            # create sqlite
+            if self.is_use_sqlite:
+                cursor_obj = self.connection_obj.cursor()
+                cursor_obj.execute('DROP TABLE IF EXISTS %s' %
+                                   (table['table_id']))
+                header_string = ', '.join([
+                    '%s %s' %
+                    (name, self.type_dict[htype]) for name, htype in zip(
+                        table['header_id'], table['header_types'])
+                ])
+                create_table_string = 'CREATE TABLE %s (%s);' % (
+                    table['table_id'], header_string)
+                cursor_obj.execute(create_table_string)
+                for row in table['rows']:
+                    value_string = ', '.join(['"%s"' % (val) for val in row])
+                    insert_row_string = 'INSERT INTO %s VALUES(%s)' % (
+                        table['table_id'], value_string)
+                    cursor_obj.execute(insert_row_string)
+
             tables[table['table_id']] = table

         return tables
@@ -287,7 +287,13 @@ class SchemaLinker:
         return match_len / (len(nlu_t) + 0.1)

-    def get_entity_linking(self, tokenizer, nlu, nlu_t, tables, col_syn_dict):
+    def get_entity_linking(self,
+                           tokenizer,
+                           nlu,
+                           nlu_t,
+                           tables,
+                           col_syn_dict,
+                           history_sql=None):
        """
        get linking between question and schema column
        """
@@ -305,8 +311,7 @@ class SchemaLinker:
             typeinfos = []
             for ii, column in enumerate(table['header_name']):
                 column = column.lower()
-                column_new = re.sub('\(.*?\)', '', column)
-                column_new = re.sub('（.*?）', '', column_new)
+                column_new = column
                 cphrase, cscore = self.get_match_phrase(
                     nlu.lower(), column_new)
                 if cscore > 0.3 and cphrase.strip() != '':
@@ -330,7 +335,6 @@ class SchemaLinker:
                 for cell in ans.keys():
                     vphrase = cell
                     vscore = 1.0
-                    # print("trie_set find:", cell, ans[cell])
                     phrase_tok = tokenizer.tokenize(vphrase)
                     if len(phrase_tok) == 0 or len(vphrase) < 2:
                         continue
@@ -408,16 +412,25 @@ class SchemaLinker:
             match_score = self.get_table_match_score(nlu_t, schema_link)
             search_result = {
-                'table_id': table['table_id'],
-                'question_knowledge': final_question,
-                'header_knowledge': final_header,
-                'schema_link': schema_link,
-                'match_score': match_score
+                'table_id':
+                table['table_id'],
+                'question_knowledge':
+                final_question,
+                'header_knowledge':
+                final_header,
+                'schema_link':
+                schema_link,
+                'match_score':
+                match_score,
+                'table_score':
+                int(table['table_id'] == history_sql['from'][0])
+                if history_sql is not None else 0
             }
             search_result_list.append(search_result)

         search_result_list = sorted(
-            search_result_list, key=lambda x: x['match_score'],
+            search_result_list,
+            key=lambda x: (x['match_score'], x['table_score']),
             reverse=True)[0:4]

         return search_result_list
@@ -95,7 +95,7 @@ class TableQuestionAnsweringPreprocessor(Preprocessor):
         # tokenize question
         question = data['question']
-        history_sql = data['history_sql']
+        history_sql = data.get('history_sql', None)
         nlu = question.lower()
         nlu_t = self.tokenizer.tokenize(nlu)
@@ -105,7 +105,8 @@ class TableQuestionAnsweringPreprocessor(Preprocessor):
             nlu=nlu,
             nlu_t=nlu_t,
             tables=self.db.tables,
-            col_syn_dict=self.db.syn_dict)
+            col_syn_dict=self.db.syn_dict,
+            history_sql=history_sql)

         # collect data
         datas = self.construct_data(
| @@ -2,8 +2,7 @@ from typing import List | |||||
| from modelscope.outputs import OutputKeys | from modelscope.outputs import OutputKeys | ||||
| from modelscope.pipelines.nlp import (ConversationalTextToSqlPipeline, | from modelscope.pipelines.nlp import (ConversationalTextToSqlPipeline, | ||||
| DialogStateTrackingPipeline, | |||||
| TableQuestionAnsweringPipeline) | |||||
| DialogStateTrackingPipeline) | |||||
| def text2sql_tracking_and_print_results( | def text2sql_tracking_and_print_results( | ||||
@@ -42,17 +41,3 @@ def tracking_and_print_dialog_states(
         print(json.dumps(result))
         history_states.extend([result[OutputKeys.OUTPUT], {}])
-
-
-def tableqa_tracking_and_print_results(
-        test_case, pipelines: List[TableQuestionAnsweringPipeline]):
-    for pipeline in pipelines:
-        historical_queries = None
-        for question in test_case['utterance']:
-            output_dict = pipeline({
-                'question': question,
-                'history_sql': historical_queries
-            })
-            print('output_dict', output_dict['output'].string,
-                  output_dict['output'].query)
-            historical_queries = output_dict['history']
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
 import unittest
+from typing import List

 from transformers import BertTokenizer
@@ -11,10 +12,60 @@ from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline
 from modelscope.preprocessors import TableQuestionAnsweringPreprocessor
 from modelscope.preprocessors.star3.fields.database import Database
 from modelscope.utils.constant import ModelFile, Tasks
-from modelscope.utils.nlp.nlp_utils import tableqa_tracking_and_print_results
 from modelscope.utils.test_utils import test_level


+def tableqa_tracking_and_print_results_with_history(
+        pipelines: List[TableQuestionAnsweringPipeline]):
+    test_case = {
+        'utterance': [
+            '有哪些风险类型?',
+            '风险类型有多少种?',
+            '珠江流域的小(2)型水库的库容总量是多少?',
+            '那平均值是多少?',
+            '那水库的名称呢?',
+            '换成中型的呢?',
+            '枣庄营业厅的电话',
+            '那地址呢?',
+            '枣庄营业厅的电话和地址',
+        ]
+    }
+    for p in pipelines:
+        historical_queries = None
+        for question in test_case['utterance']:
+            output_dict = p({
+                'question': question,
+                'history_sql': historical_queries
+            })
+            print('question', question)
+            print('sql text:', output_dict['output'].string)
+            print('sql query:', output_dict['output'].query)
+            print('query result:', output_dict['query_result'])
+            print()
+            historical_queries = output_dict['history']
+
+
+def tableqa_tracking_and_print_results_without_history(
+        pipelines: List[TableQuestionAnsweringPipeline]):
+    test_case = {
+        'utterance': [
+            '有哪些风险类型?',
+            '风险类型有多少种?',
+            '珠江流域的小(2)型水库的库容总量是多少?',
+            '枣庄营业厅的电话',
+            '枣庄营业厅的电话和地址',
+        ]
+    }
+    for p in pipelines:
+        for question in test_case['utterance']:
+            output_dict = p({'question': question})
+            print('question', question)
+            print('sql text:', output_dict['output'].string)
+            print('sql query:', output_dict['output'].query)
+            print('query result:', output_dict['query_result'])
+            print()
+
+
 class TableQuestionAnswering(unittest.TestCase):

     def setUp(self) -> None:
@@ -22,20 +73,18 @@ class TableQuestionAnswering(unittest.TestCase):
         self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn'

     model_id = 'damo/nlp_convai_text2sql_pretrain_cn'
-    test_case = {
-        'utterance':
-        ['长江流域的小(2)型水库的库容总量是多少?', '那平均值是多少?', '那水库的名称呢?', '换成中型的呢?']
-    }

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_by_direct_model_download(self):
         cache_path = snapshot_download(self.model_id)
         preprocessor = TableQuestionAnsweringPreprocessor(model_dir=cache_path)
         pipelines = [
-            TableQuestionAnsweringPipeline(
-                model=cache_path, preprocessor=preprocessor)
+            pipeline(
+                Tasks.table_question_answering,
+                model=cache_path,
+                preprocessor=preprocessor)
         ]
-        tableqa_tracking_and_print_results(self.test_case, pipelines)
+        tableqa_tracking_and_print_results_with_history(pipelines)

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_model_from_modelhub(self):
@@ -43,15 +92,17 @@ class TableQuestionAnswering(unittest.TestCase):
         preprocessor = TableQuestionAnsweringPreprocessor(
             model_dir=model.model_dir)
         pipelines = [
-            TableQuestionAnsweringPipeline(
-                model=model, preprocessor=preprocessor)
+            pipeline(
+                Tasks.table_question_answering,
+                model=model,
+                preprocessor=preprocessor)
         ]
-        tableqa_tracking_and_print_results(self.test_case, pipelines)
+        tableqa_tracking_and_print_results_with_history(pipelines)

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_model_from_task(self):
         pipelines = [pipeline(Tasks.table_question_answering, self.model_id)]
-        tableqa_tracking_and_print_results(self.test_case, pipelines)
+        tableqa_tracking_and_print_results_with_history(pipelines)

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_model_from_modelhub_with_other_classes(self):
@@ -60,15 +111,24 @@ class TableQuestionAnswering(unittest.TestCase):
             os.path.join(model.model_dir, ModelFile.VOCAB_FILE))
         db = Database(
             tokenizer=self.tokenizer,
-            table_file_path=os.path.join(model.model_dir, 'table.json'),
-            syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'))
+            table_file_path=[
+                os.path.join(model.model_dir, 'databases', fname)
+                for fname in os.listdir(
+                    os.path.join(model.model_dir, 'databases'))
+            ],
+            syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'),
+            is_use_sqlite=True)
         preprocessor = TableQuestionAnsweringPreprocessor(
             model_dir=model.model_dir, db=db)
         pipelines = [
-            TableQuestionAnsweringPipeline(
-                model=model, preprocessor=preprocessor, db=db)
+            pipeline(
+                Tasks.table_question_answering,
+                model=model,
+                preprocessor=preprocessor,
+                db=db)
         ]
-        tableqa_tracking_and_print_results(self.test_case, pipelines)
+        tableqa_tracking_and_print_results_without_history(pipelines)
+        tableqa_tracking_and_print_results_with_history(pipelines)


 if __name__ == '__main__':
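
For reference, a self-contained sketch of the in-memory SQLite pattern the new Database / pipeline code is built on. The table name, columns, and row below are hypothetical stand-ins for what the real code derives from table['table_id'], table['header_id'], and table['rows']:

import json
import sqlite3

# in-memory database, as in Database(..., is_use_sqlite=True)
connection_obj = sqlite3.connect(':memory:')
cursor_obj = connection_obj.cursor()

# hypothetical table; the real code maps header types through
# {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'}
cursor_obj.execute('CREATE TABLE demo_table (col1 TEXT, col2 INT)')
cursor_obj.execute('INSERT INTO demo_table VALUES("枣庄营业厅", 1)')

# execute a generated query and package the rows the way the pipeline does
cursor = cursor_obj.execute('SELECT col1, col2 FROM demo_table')
names = [{'name': d[0], 'label': d[0]} for d in cursor.description]
cells = [{n['name']: cell
          for n, cell in zip(names, row)} for row in cursor.fetchall()]
print(json.dumps({'headers': names, 'cells': cells}, ensure_ascii=False))
# -> {"headers": [...], "cells": [{"col1": "枣庄营业厅", "col2": 1}]}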