diff --git a/modelscope/pipelines/nlp/table_question_answering_pipeline.py b/modelscope/pipelines/nlp/table_question_answering_pipeline.py index b75a8153..bde78196 100644 --- a/modelscope/pipelines/nlp/table_question_answering_pipeline.py +++ b/modelscope/pipelines/nlp/table_question_answering_pipeline.py @@ -231,19 +231,6 @@ class TableQuestionAnsweringPipeline(Pipeline): header_ids = table['header_id'] + ['null'] sql = result['sql'] - str_sel_list, sql_sel_list = [], [] - for idx, sel in enumerate(sql['sel']): - header_name = header_names[sel] - header_id = '`%s`.`%s`' % (table['table_id'], header_ids[sel]) - if sql['agg'][idx] == 0: - str_sel_list.append(header_name) - sql_sel_list.append(header_id) - else: - str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' - + header_name + ')') - sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' - + header_id + ')') - str_cond_list, sql_cond_list = [], [] where_conds, orderby_conds = [], [] for cond in sql['conds']: @@ -285,9 +272,34 @@ class TableQuestionAnsweringPipeline(Pipeline): if is_in: str_orderby += ' LIMIT %d' % (limit_num) sql_orderby += ' LIMIT %d' % (limit_num) + # post process null column + for idx, sel in enumerate(sql['sel']): + if sel == len(header_ids) - 1: + primary_sel = 0 + for index, attrib in enumerate(table['header_attribute']): + if attrib == 'PRIMARY': + primary_sel = index + break + if primary_sel not in sql['sel']: + sql['sel'][idx] = primary_sel + else: + del sql['sel'][idx] else: str_orderby = '' + str_sel_list, sql_sel_list = [], [] + for idx, sel in enumerate(sql['sel']): + header_name = header_names[sel] + header_id = '`%s`.`%s`' % (table['table_id'], header_ids[sel]) + if sql['agg'][idx] == 0: + str_sel_list.append(header_name) + sql_sel_list.append(header_id) + else: + str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' + + header_name + ')') + sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' + + header_id + ')') + if len(str_cond_list) != 0 and len(str_orderby) != 0: final_str = 'SELECT %s FROM %s WHERE %s ORDER BY %s' % ( ', '.join(str_sel_list), table['table_name'], str_where_conds, diff --git a/modelscope/preprocessors/nlp/space_T_cn/fields/database.py b/modelscope/preprocessors/nlp/space_T_cn/fields/database.py index 2fef8d7e..5ceb5c79 100644 --- a/modelscope/preprocessors/nlp/space_T_cn/fields/database.py +++ b/modelscope/preprocessors/nlp/space_T_cn/fields/database.py @@ -20,9 +20,9 @@ class Database: self.connection_obj = sqlite3.connect( ':memory:', check_same_thread=False) self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'} - self.tables = self.init_tables(table_file_path=table_file_path) self.syn_dict = self.init_syn_dict( syn_dict_file_path=syn_dict_file_path) + self.tables = self.init_tables(table_file_path=table_file_path) def __del__(self): if self.is_use_sqlite: @@ -75,6 +75,10 @@ class Database: continue word = str(cell).strip().lower() trie_set[ii].insert(word, word) + if word in self.syn_dict.keys(): + for term in self.syn_dict[word]: + if term.strip() != '': + trie_set[ii].insert(term, word) table['value_trie'] = trie_set diff --git a/tests/pipelines/test_table_question_answering.py b/tests/pipelines/test_table_question_answering.py index 825d8f23..9faed993 100644 --- a/tests/pipelines/test_table_question_answering.py +++ b/tests/pipelines/test_table_question_answering.py @@ -24,13 +24,10 @@ def tableqa_tracking_and_print_results_with_history( 'utterance': [ '有哪些风险类型?', '风险类型有多少种?', - '珠江流域的小(2)型水库的库容总量是多少?', + '珠江流域的小型水库的库容总量是多少?', '那平均值是多少?', '那水库的名称呢?', '换成中型的呢?', - '枣庄营业厅的电话', - '那地址呢?', - '枣庄营业厅的电话和地址', ] } for p in pipelines: @@ -55,9 +52,7 @@ def tableqa_tracking_and_print_results_without_history( 'utterance': [ '有哪些风险类型?', '风险类型有多少种?', - '珠江流域的小(2)型水库的库容总量是多少?', - '枣庄营业厅的电话', - '枣庄营业厅的电话和地址', + '珠江流域的小型水库的库容总量是多少?', ] } for p in pipelines: @@ -77,13 +72,10 @@ def tableqa_tracking_and_print_results_with_tableid( 'utterance': [ ['有哪些风险类型?', 'fund'], ['风险类型有多少种?', 'reservoir'], - ['珠江流域的小(2)型水库的库容总量是多少?', 'reservoir'], + ['珠江流域的小型水库的库容总量是多少?', 'reservoir'], ['那平均值是多少?', 'reservoir'], ['那水库的名称呢?', 'reservoir'], ['换成中型的呢?', 'reservoir'], - ['枣庄营业厅的电话', 'business'], - ['那地址呢?', 'business'], - ['枣庄营业厅的电话和地址', 'business'], ], } for p in pipelines: @@ -157,7 +149,7 @@ class TableQuestionAnswering(unittest.TestCase): os.path.join(model.model_dir, 'databases')) ], syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'), - is_use_sqlite=False) + is_use_sqlite=True) preprocessor = TableQuestionAnsweringPreprocessor( model_dir=model.model_dir, db=db) pipelines = [