主要做了如下修改:
1. 加入了同义词词典
2. 对SQL进行后处理,如果包含排序,则将空列转化成Primary列
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10670121
master
| @@ -231,19 +231,6 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||
| header_ids = table['header_id'] + ['null'] | |||
| sql = result['sql'] | |||
| str_sel_list, sql_sel_list = [], [] | |||
| for idx, sel in enumerate(sql['sel']): | |||
| header_name = header_names[sel] | |||
| header_id = '`%s`.`%s`' % (table['table_id'], header_ids[sel]) | |||
| if sql['agg'][idx] == 0: | |||
| str_sel_list.append(header_name) | |||
| sql_sel_list.append(header_id) | |||
| else: | |||
| str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' | |||
| + header_name + ')') | |||
| sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' | |||
| + header_id + ')') | |||
| str_cond_list, sql_cond_list = [], [] | |||
| where_conds, orderby_conds = [], [] | |||
| for cond in sql['conds']: | |||
| @@ -285,9 +272,34 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||
| if is_in: | |||
| str_orderby += ' LIMIT %d' % (limit_num) | |||
| sql_orderby += ' LIMIT %d' % (limit_num) | |||
| # post process null column | |||
| for idx, sel in enumerate(sql['sel']): | |||
| if sel == len(header_ids) - 1: | |||
| primary_sel = 0 | |||
| for index, attrib in enumerate(table['header_attribute']): | |||
| if attrib == 'PRIMARY': | |||
| primary_sel = index | |||
| break | |||
| if primary_sel not in sql['sel']: | |||
| sql['sel'][idx] = primary_sel | |||
| else: | |||
| del sql['sel'][idx] | |||
| else: | |||
| str_orderby = '' | |||
| str_sel_list, sql_sel_list = [], [] | |||
| for idx, sel in enumerate(sql['sel']): | |||
| header_name = header_names[sel] | |||
| header_id = '`%s`.`%s`' % (table['table_id'], header_ids[sel]) | |||
| if sql['agg'][idx] == 0: | |||
| str_sel_list.append(header_name) | |||
| sql_sel_list.append(header_id) | |||
| else: | |||
| str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' | |||
| + header_name + ')') | |||
| sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' | |||
| + header_id + ')') | |||
| if len(str_cond_list) != 0 and len(str_orderby) != 0: | |||
| final_str = 'SELECT %s FROM %s WHERE %s ORDER BY %s' % ( | |||
| ', '.join(str_sel_list), table['table_name'], str_where_conds, | |||
| @@ -20,9 +20,9 @@ class Database: | |||
| self.connection_obj = sqlite3.connect( | |||
| ':memory:', check_same_thread=False) | |||
| self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'} | |||
| self.tables = self.init_tables(table_file_path=table_file_path) | |||
| self.syn_dict = self.init_syn_dict( | |||
| syn_dict_file_path=syn_dict_file_path) | |||
| self.tables = self.init_tables(table_file_path=table_file_path) | |||
| def __del__(self): | |||
| if self.is_use_sqlite: | |||
| @@ -75,6 +75,10 @@ class Database: | |||
| continue | |||
| word = str(cell).strip().lower() | |||
| trie_set[ii].insert(word, word) | |||
| if word in self.syn_dict.keys(): | |||
| for term in self.syn_dict[word]: | |||
| if term.strip() != '': | |||
| trie_set[ii].insert(term, word) | |||
| table['value_trie'] = trie_set | |||
| @@ -24,13 +24,10 @@ def tableqa_tracking_and_print_results_with_history( | |||
| 'utterance': [ | |||
| '有哪些风险类型?', | |||
| '风险类型有多少种?', | |||
| '珠江流域的小(2)型水库的库容总量是多少?', | |||
| '珠江流域的小型水库的库容总量是多少?', | |||
| '那平均值是多少?', | |||
| '那水库的名称呢?', | |||
| '换成中型的呢?', | |||
| '枣庄营业厅的电话', | |||
| '那地址呢?', | |||
| '枣庄营业厅的电话和地址', | |||
| ] | |||
| } | |||
| for p in pipelines: | |||
| @@ -55,9 +52,7 @@ def tableqa_tracking_and_print_results_without_history( | |||
| 'utterance': [ | |||
| '有哪些风险类型?', | |||
| '风险类型有多少种?', | |||
| '珠江流域的小(2)型水库的库容总量是多少?', | |||
| '枣庄营业厅的电话', | |||
| '枣庄营业厅的电话和地址', | |||
| '珠江流域的小型水库的库容总量是多少?', | |||
| ] | |||
| } | |||
| for p in pipelines: | |||
| @@ -77,13 +72,10 @@ def tableqa_tracking_and_print_results_with_tableid( | |||
| 'utterance': [ | |||
| ['有哪些风险类型?', 'fund'], | |||
| ['风险类型有多少种?', 'reservoir'], | |||
| ['珠江流域的小(2)型水库的库容总量是多少?', 'reservoir'], | |||
| ['珠江流域的小型水库的库容总量是多少?', 'reservoir'], | |||
| ['那平均值是多少?', 'reservoir'], | |||
| ['那水库的名称呢?', 'reservoir'], | |||
| ['换成中型的呢?', 'reservoir'], | |||
| ['枣庄营业厅的电话', 'business'], | |||
| ['那地址呢?', 'business'], | |||
| ['枣庄营业厅的电话和地址', 'business'], | |||
| ], | |||
| } | |||
| for p in pipelines: | |||
| @@ -157,7 +149,7 @@ class TableQuestionAnswering(unittest.TestCase): | |||
| os.path.join(model.model_dir, 'databases')) | |||
| ], | |||
| syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'), | |||
| is_use_sqlite=False) | |||
| is_use_sqlite=True) | |||
| preprocessor = TableQuestionAnsweringPreprocessor( | |||
| model_dir=model.model_dir, db=db) | |||
| pipelines = [ | |||