1. 增加传入table_id
2. 将result和table的结构统一
3. 默认开启is_use_sqlite
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10492027
master
| @@ -72,6 +72,7 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||||
| action = self.action_ops[result['action']] | action = self.action_ops[result['action']] | ||||
| headers = table['header_name'] | headers = table['header_name'] | ||||
| current_sql = result['sql'] | current_sql = result['sql'] | ||||
| current_sql['from'] = [table['table_id']] | |||||
| if history_sql is None: | if history_sql is None: | ||||
| return current_sql | return current_sql | ||||
| @@ -216,10 +217,11 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||||
| else: | else: | ||||
| return current_sql | return current_sql | ||||
| def sql_dict_to_str(self, result, table): | |||||
| def sql_dict_to_str(self, result, tables): | |||||
| """ | """ | ||||
| convert sql struct to string | convert sql struct to string | ||||
| """ | """ | ||||
| table = tables[result['sql']['from'][0]] | |||||
| header_names = table['header_name'] + ['空列'] | header_names = table['header_name'] + ['空列'] | ||||
| header_ids = table['header_id'] + ['null'] | header_ids = table['header_id'] + ['null'] | ||||
| sql = result['sql'] | sql = result['sql'] | ||||
| @@ -279,42 +281,43 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||||
| """ | """ | ||||
| result = inputs['result'] | result = inputs['result'] | ||||
| history_sql = inputs['history_sql'] | history_sql = inputs['history_sql'] | ||||
| result['sql'] = self.post_process_multi_turn( | |||||
| history_sql=history_sql, | |||||
| result=result, | |||||
| table=self.db.tables[result['table_id']]) | |||||
| result['sql']['from'] = [result['table_id']] | |||||
| sql = self.sql_dict_to_str( | |||||
| result=result, table=self.db.tables[result['table_id']]) | |||||
| try: | |||||
| result['sql'] = self.post_process_multi_turn( | |||||
| history_sql=history_sql, | |||||
| result=result, | |||||
| table=self.db.tables[result['table_id']]) | |||||
| except Exception: | |||||
| result['sql'] = history_sql | |||||
| sql = self.sql_dict_to_str(result=result, tables=self.db.tables) | |||||
| # add sqlite | # add sqlite | ||||
| if self.db.is_use_sqlite: | if self.db.is_use_sqlite: | ||||
| try: | try: | ||||
| cursor = self.db.connection_obj.cursor().execute(sql.query) | cursor = self.db.connection_obj.cursor().execute(sql.query) | ||||
| names = [{ | |||||
| 'name': | |||||
| description[0], | |||||
| 'label': | |||||
| self.db.tables[result['table_id']]['headerid2name'].get( | |||||
| description[0], description[0]) | |||||
| } for description in cursor.description] | |||||
| cells = [] | |||||
| header_ids, header_names = [], [] | |||||
| for description in cursor.description: | |||||
| header_ids.append(self.db.tables[result['table_id']] | |||||
| ['headerid2name'].get( | |||||
| description[0], description[0])) | |||||
| header_names.append(description[0]) | |||||
| rows = [] | |||||
| for res in cursor.fetchall(): | for res in cursor.fetchall(): | ||||
| row = {} | |||||
| for name, cell in zip(names, res): | |||||
| row[name['name']] = cell | |||||
| cells.append(row) | |||||
| tabledata = {'headers': names, 'cells': cells} | |||||
| rows.append(list(res)) | |||||
| tabledata = { | |||||
| 'header_id': header_ids, | |||||
| 'header_name': header_names, | |||||
| 'rows': rows | |||||
| } | |||||
| except Exception: | except Exception: | ||||
| tabledata = {'headers': [], 'cells': []} | |||||
| tabledata = {'header_id': [], 'header_name': [], 'rows': []} | |||||
| else: | else: | ||||
| tabledata = {'headers': [], 'cells': []} | |||||
| tabledata = {'header_id': [], 'header_name': [], 'rows': []} | |||||
| output = { | output = { | ||||
| OutputKeys.SQL_STRING: sql.string, | OutputKeys.SQL_STRING: sql.string, | ||||
| OutputKeys.SQL_QUERY: sql.query, | OutputKeys.SQL_QUERY: sql.query, | ||||
| OutputKeys.HISTORY: result['sql'], | OutputKeys.HISTORY: result['sql'], | ||||
| OutputKeys.QUERT_RESULT: json.dumps(tabledata, ensure_ascii=False), | |||||
| OutputKeys.QUERT_RESULT: tabledata, | |||||
| } | } | ||||
| return output | return output | ||||
| @@ -13,7 +13,7 @@ class Database: | |||||
| tokenizer, | tokenizer, | ||||
| table_file_path, | table_file_path, | ||||
| syn_dict_file_path, | syn_dict_file_path, | ||||
| is_use_sqlite=False): | |||||
| is_use_sqlite=True): | |||||
| self.tokenizer = tokenizer | self.tokenizer = tokenizer | ||||
| self.is_use_sqlite = is_use_sqlite | self.is_use_sqlite = is_use_sqlite | ||||
| if self.is_use_sqlite: | if self.is_use_sqlite: | ||||
| @@ -293,6 +293,7 @@ class SchemaLinker: | |||||
| nlu_t, | nlu_t, | ||||
| tables, | tables, | ||||
| col_syn_dict, | col_syn_dict, | ||||
| table_id=None, | |||||
| history_sql=None): | history_sql=None): | ||||
| """ | """ | ||||
| get linking between question and schema column | get linking between question and schema column | ||||
| @@ -300,6 +301,9 @@ class SchemaLinker: | |||||
| typeinfos = [] | typeinfos = [] | ||||
| numbers = re.findall(r'[-]?\d*\.\d+|[-]?\d+|\d+', nlu) | numbers = re.findall(r'[-]?\d*\.\d+|[-]?\d+|\d+', nlu) | ||||
| if table_id is not None and table_id in tables: | |||||
| tables = {table_id: tables[table_id]} | |||||
| # search schema link in every table | # search schema link in every table | ||||
| search_result_list = [] | search_result_list = [] | ||||
| for tablename in tables: | for tablename in tables: | ||||
| @@ -411,26 +415,25 @@ class SchemaLinker: | |||||
| # get the match score of each table | # get the match score of each table | ||||
| match_score = self.get_table_match_score(nlu_t, schema_link) | match_score = self.get_table_match_score(nlu_t, schema_link) | ||||
| # cal table_score | |||||
| if history_sql is not None and 'from' in history_sql: | |||||
| table_score = int(table['table_id'] == history_sql['from'][0]) | |||||
| else: | |||||
| table_score = 0 | |||||
| search_result = { | search_result = { | ||||
| 'table_id': | |||||
| table['table_id'], | |||||
| 'question_knowledge': | |||||
| final_question, | |||||
| 'header_knowledge': | |||||
| final_header, | |||||
| 'schema_link': | |||||
| schema_link, | |||||
| 'match_score': | |||||
| match_score, | |||||
| 'table_score': | |||||
| int(table['table_id'] == history_sql['from'][0]) | |||||
| if history_sql is not None else 0 | |||||
| 'table_id': table['table_id'], | |||||
| 'question_knowledge': final_question, | |||||
| 'header_knowledge': final_header, | |||||
| 'schema_link': schema_link, | |||||
| 'match_score': match_score, | |||||
| 'table_score': table_score | |||||
| } | } | ||||
| search_result_list.append(search_result) | search_result_list.append(search_result) | ||||
| search_result_list = sorted( | search_result_list = sorted( | ||||
| search_result_list, | search_result_list, | ||||
| key=lambda x: (x['match_score'], x['table_score']), | key=lambda x: (x['match_score'], x['table_score']), | ||||
| reverse=True)[0:4] | |||||
| reverse=True)[0:1] | |||||
| return search_result_list | return search_result_list | ||||
| @@ -95,6 +95,7 @@ class TableQuestionAnsweringPreprocessor(Preprocessor): | |||||
| # tokenize question | # tokenize question | ||||
| question = data['question'] | question = data['question'] | ||||
| table_id = data.get('table_id', None) | |||||
| history_sql = data.get('history_sql', None) | history_sql = data.get('history_sql', None) | ||||
| nlu = question.lower() | nlu = question.lower() | ||||
| nlu_t = self.tokenizer.tokenize(nlu) | nlu_t = self.tokenizer.tokenize(nlu) | ||||
| @@ -106,6 +107,7 @@ class TableQuestionAnsweringPreprocessor(Preprocessor): | |||||
| nlu_t=nlu_t, | nlu_t=nlu_t, | ||||
| tables=self.db.tables, | tables=self.db.tables, | ||||
| col_syn_dict=self.db.syn_dict, | col_syn_dict=self.db.syn_dict, | ||||
| table_id=table_id, | |||||
| history_sql=history_sql) | history_sql=history_sql) | ||||
| # collect data | # collect data | ||||
| @@ -43,7 +43,7 @@ def tableqa_tracking_and_print_results_with_history( | |||||
| print('sql text:', output_dict[OutputKeys.SQL_STRING]) | print('sql text:', output_dict[OutputKeys.SQL_STRING]) | ||||
| print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | ||||
| print('query result:', output_dict[OutputKeys.QUERT_RESULT]) | print('query result:', output_dict[OutputKeys.QUERT_RESULT]) | ||||
| print('json dumps', json.dumps(output_dict)) | |||||
| print('json dumps', json.dumps(output_dict, ensure_ascii=False)) | |||||
| print() | print() | ||||
| historical_queries = output_dict[OutputKeys.HISTORY] | historical_queries = output_dict[OutputKeys.HISTORY] | ||||
| @@ -66,10 +66,42 @@ def tableqa_tracking_and_print_results_without_history( | |||||
| print('sql text:', output_dict[OutputKeys.SQL_STRING]) | print('sql text:', output_dict[OutputKeys.SQL_STRING]) | ||||
| print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | ||||
| print('query result:', output_dict[OutputKeys.QUERT_RESULT]) | print('query result:', output_dict[OutputKeys.QUERT_RESULT]) | ||||
| print('json dumps', json.dumps(output_dict)) | |||||
| print('json dumps', json.dumps(output_dict, ensure_ascii=False)) | |||||
| print() | print() | ||||
| def tableqa_tracking_and_print_results_with_tableid( | |||||
| pipelines: List[TableQuestionAnsweringPipeline]): | |||||
| test_case = { | |||||
| 'utterance': [ | |||||
| ['有哪些风险类型?', 'fund'], | |||||
| ['风险类型有多少种?', 'reservoir'], | |||||
| ['珠江流域的小(2)型水库的库容总量是多少?', 'reservoir'], | |||||
| ['那平均值是多少?', 'reservoir'], | |||||
| ['那水库的名称呢?', 'reservoir'], | |||||
| ['换成中型的呢?', 'reservoir'], | |||||
| ['枣庄营业厅的电话', 'business'], | |||||
| ['那地址呢?', 'business'], | |||||
| ['枣庄营业厅的电话和地址', 'business'], | |||||
| ], | |||||
| } | |||||
| for p in pipelines: | |||||
| historical_queries = None | |||||
| for question, table_id in test_case['utterance']: | |||||
| output_dict = p({ | |||||
| 'question': question, | |||||
| 'table_id': table_id, | |||||
| 'history_sql': historical_queries | |||||
| }) | |||||
| print('question', question) | |||||
| print('sql text:', output_dict[OutputKeys.SQL_STRING]) | |||||
| print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | |||||
| print('query result:', output_dict[OutputKeys.QUERT_RESULT]) | |||||
| print('json dumps', json.dumps(output_dict, ensure_ascii=False)) | |||||
| print() | |||||
| historical_queries = output_dict[OutputKeys.HISTORY] | |||||
| class TableQuestionAnswering(unittest.TestCase): | class TableQuestionAnswering(unittest.TestCase): | ||||
| def setUp(self) -> None: | def setUp(self) -> None: | ||||
| @@ -93,15 +125,27 @@ class TableQuestionAnswering(unittest.TestCase): | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
| def test_run_with_model_from_modelhub(self): | def test_run_with_model_from_modelhub(self): | ||||
| model = Model.from_pretrained(self.model_id) | model = Model.from_pretrained(self.model_id) | ||||
| self.tokenizer = BertTokenizer( | |||||
| os.path.join(model.model_dir, ModelFile.VOCAB_FILE)) | |||||
| db = Database( | |||||
| tokenizer=self.tokenizer, | |||||
| table_file_path=[ | |||||
| os.path.join(model.model_dir, 'databases', fname) | |||||
| for fname in os.listdir( | |||||
| os.path.join(model.model_dir, 'databases')) | |||||
| ], | |||||
| syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'), | |||||
| is_use_sqlite=False) | |||||
| preprocessor = TableQuestionAnsweringPreprocessor( | preprocessor = TableQuestionAnsweringPreprocessor( | ||||
| model_dir=model.model_dir) | |||||
| model_dir=model.model_dir, db=db) | |||||
| pipelines = [ | pipelines = [ | ||||
| pipeline( | pipeline( | ||||
| Tasks.table_question_answering, | Tasks.table_question_answering, | ||||
| model=model, | model=model, | ||||
| preprocessor=preprocessor) | |||||
| preprocessor=preprocessor, | |||||
| db=db) | |||||
| ] | ] | ||||
| tableqa_tracking_and_print_results_with_history(pipelines) | |||||
| tableqa_tracking_and_print_results_with_tableid(pipelines) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_run_with_model_from_task(self): | def test_run_with_model_from_task(self): | ||||
| @@ -132,7 +176,6 @@ class TableQuestionAnswering(unittest.TestCase): | |||||
| db=db) | db=db) | ||||
| ] | ] | ||||
| tableqa_tracking_and_print_results_without_history(pipelines) | tableqa_tracking_and_print_results_without_history(pipelines) | ||||
| tableqa_tracking_and_print_results_with_history(pipelines) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||