主要做了如下修改:
2. 加入了同义词词典:先加载同义词词典再初始化数据表,构建取值前缀树(value_trie)时,将单元格取值的非空同义词一并插入,并映射回原始取值
3. 对SQL进行后处理:如果包含排序(ORDER BY),则将SELECT中的空列(null列)替换为第一个PRIMARY列(未找到时默认为第0列);若该列已在SELECT中,则直接删除空列
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10670121
master
| @@ -231,19 +231,6 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||||
| header_ids = table['header_id'] + ['null'] | header_ids = table['header_id'] + ['null'] | ||||
| sql = result['sql'] | sql = result['sql'] | ||||
| str_sel_list, sql_sel_list = [], [] | |||||
| for idx, sel in enumerate(sql['sel']): | |||||
| header_name = header_names[sel] | |||||
| header_id = '`%s`.`%s`' % (table['table_id'], header_ids[sel]) | |||||
| if sql['agg'][idx] == 0: | |||||
| str_sel_list.append(header_name) | |||||
| sql_sel_list.append(header_id) | |||||
| else: | |||||
| str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' | |||||
| + header_name + ')') | |||||
| sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' | |||||
| + header_id + ')') | |||||
| str_cond_list, sql_cond_list = [], [] | str_cond_list, sql_cond_list = [], [] | ||||
| where_conds, orderby_conds = [], [] | where_conds, orderby_conds = [], [] | ||||
| for cond in sql['conds']: | for cond in sql['conds']: | ||||
| @@ -285,9 +272,34 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||||
| if is_in: | if is_in: | ||||
| str_orderby += ' LIMIT %d' % (limit_num) | str_orderby += ' LIMIT %d' % (limit_num) | ||||
| sql_orderby += ' LIMIT %d' % (limit_num) | sql_orderby += ' LIMIT %d' % (limit_num) | ||||
| # post process null column | |||||
| for idx, sel in enumerate(sql['sel']): | |||||
| if sel == len(header_ids) - 1: | |||||
| primary_sel = 0 | |||||
| for index, attrib in enumerate(table['header_attribute']): | |||||
| if attrib == 'PRIMARY': | |||||
| primary_sel = index | |||||
| break | |||||
| if primary_sel not in sql['sel']: | |||||
| sql['sel'][idx] = primary_sel | |||||
| else: | |||||
| del sql['sel'][idx] | |||||
| else: | else: | ||||
| str_orderby = '' | str_orderby = '' | ||||
| str_sel_list, sql_sel_list = [], [] | |||||
| for idx, sel in enumerate(sql['sel']): | |||||
| header_name = header_names[sel] | |||||
| header_id = '`%s`.`%s`' % (table['table_id'], header_ids[sel]) | |||||
| if sql['agg'][idx] == 0: | |||||
| str_sel_list.append(header_name) | |||||
| sql_sel_list.append(header_id) | |||||
| else: | |||||
| str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' | |||||
| + header_name + ')') | |||||
| sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' | |||||
| + header_id + ')') | |||||
| if len(str_cond_list) != 0 and len(str_orderby) != 0: | if len(str_cond_list) != 0 and len(str_orderby) != 0: | ||||
| final_str = 'SELECT %s FROM %s WHERE %s ORDER BY %s' % ( | final_str = 'SELECT %s FROM %s WHERE %s ORDER BY %s' % ( | ||||
| ', '.join(str_sel_list), table['table_name'], str_where_conds, | ', '.join(str_sel_list), table['table_name'], str_where_conds, | ||||
| @@ -20,9 +20,9 @@ class Database: | |||||
| self.connection_obj = sqlite3.connect( | self.connection_obj = sqlite3.connect( | ||||
| ':memory:', check_same_thread=False) | ':memory:', check_same_thread=False) | ||||
| self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'} | self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'} | ||||
| self.tables = self.init_tables(table_file_path=table_file_path) | |||||
| self.syn_dict = self.init_syn_dict( | self.syn_dict = self.init_syn_dict( | ||||
| syn_dict_file_path=syn_dict_file_path) | syn_dict_file_path=syn_dict_file_path) | ||||
| self.tables = self.init_tables(table_file_path=table_file_path) | |||||
| def __del__(self): | def __del__(self): | ||||
| if self.is_use_sqlite: | if self.is_use_sqlite: | ||||
| @@ -75,6 +75,10 @@ class Database: | |||||
| continue | continue | ||||
| word = str(cell).strip().lower() | word = str(cell).strip().lower() | ||||
| trie_set[ii].insert(word, word) | trie_set[ii].insert(word, word) | ||||
| if word in self.syn_dict.keys(): | |||||
| for term in self.syn_dict[word]: | |||||
| if term.strip() != '': | |||||
| trie_set[ii].insert(term, word) | |||||
| table['value_trie'] = trie_set | table['value_trie'] = trie_set | ||||
| @@ -24,13 +24,10 @@ def tableqa_tracking_and_print_results_with_history( | |||||
| 'utterance': [ | 'utterance': [ | ||||
| '有哪些风险类型?', | '有哪些风险类型?', | ||||
| '风险类型有多少种?', | '风险类型有多少种?', | ||||
| '珠江流域的小(2)型水库的库容总量是多少?', | |||||
| '珠江流域的小型水库的库容总量是多少?', | |||||
| '那平均值是多少?', | '那平均值是多少?', | ||||
| '那水库的名称呢?', | '那水库的名称呢?', | ||||
| '换成中型的呢?', | '换成中型的呢?', | ||||
| '枣庄营业厅的电话', | |||||
| '那地址呢?', | |||||
| '枣庄营业厅的电话和地址', | |||||
| ] | ] | ||||
| } | } | ||||
| for p in pipelines: | for p in pipelines: | ||||
| @@ -55,9 +52,7 @@ def tableqa_tracking_and_print_results_without_history( | |||||
| 'utterance': [ | 'utterance': [ | ||||
| '有哪些风险类型?', | '有哪些风险类型?', | ||||
| '风险类型有多少种?', | '风险类型有多少种?', | ||||
| '珠江流域的小(2)型水库的库容总量是多少?', | |||||
| '枣庄营业厅的电话', | |||||
| '枣庄营业厅的电话和地址', | |||||
| '珠江流域的小型水库的库容总量是多少?', | |||||
| ] | ] | ||||
| } | } | ||||
| for p in pipelines: | for p in pipelines: | ||||
| @@ -77,13 +72,10 @@ def tableqa_tracking_and_print_results_with_tableid( | |||||
| 'utterance': [ | 'utterance': [ | ||||
| ['有哪些风险类型?', 'fund'], | ['有哪些风险类型?', 'fund'], | ||||
| ['风险类型有多少种?', 'reservoir'], | ['风险类型有多少种?', 'reservoir'], | ||||
| ['珠江流域的小(2)型水库的库容总量是多少?', 'reservoir'], | |||||
| ['珠江流域的小型水库的库容总量是多少?', 'reservoir'], | |||||
| ['那平均值是多少?', 'reservoir'], | ['那平均值是多少?', 'reservoir'], | ||||
| ['那水库的名称呢?', 'reservoir'], | ['那水库的名称呢?', 'reservoir'], | ||||
| ['换成中型的呢?', 'reservoir'], | ['换成中型的呢?', 'reservoir'], | ||||
| ['枣庄营业厅的电话', 'business'], | |||||
| ['那地址呢?', 'business'], | |||||
| ['枣庄营业厅的电话和地址', 'business'], | |||||
| ], | ], | ||||
| } | } | ||||
| for p in pipelines: | for p in pipelines: | ||||
| @@ -157,7 +149,7 @@ class TableQuestionAnswering(unittest.TestCase): | |||||
| os.path.join(model.model_dir, 'databases')) | os.path.join(model.model_dir, 'databases')) | ||||
| ], | ], | ||||
| syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'), | syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'), | ||||
| is_use_sqlite=False) | |||||
| is_use_sqlite=True) | |||||
| preprocessor = TableQuestionAnsweringPreprocessor( | preprocessor = TableQuestionAnsweringPreprocessor( | ||||
| model_dir=model.model_dir, db=db) | model_dir=model.model_dir, db=db) | ||||
| pipelines = [ | pipelines = [ | ||||