1. 增加传入table_id
2. 将result和table的结构统一
3. 默认开启is_use_sqlite
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10492027
master
@@ -72,6 +72,7 @@ class TableQuestionAnsweringPipeline(Pipeline):
action = self.action_ops[result['action']]
headers = table['header_name']
current_sql = result['sql']
current_sql['from'] = [table['table_id']]
if history_sql is None:
return current_sql
@@ -216,10 +217,11 @@ class TableQuestionAnsweringPipeline(Pipeline):
else:
return current_sql
def sql_dict_to_str(self, result, table):
def sql_dict_to_str(self, result, tables):
"""
convert sql struct to string
"""
table = tables[result['sql']['from'][0]]
header_names = table['header_name'] + ['空列']
header_ids = table['header_id'] + ['null']
sql = result['sql']
@@ -279,42 +281,43 @@ class TableQuestionAnsweringPipeline(Pipeline):
"""
result = inputs['result']
history_sql = inputs['history_sql']
result['sql'] = self.post_process_multi_turn(
history_sql=history_sql,
result=result,
table=self.db.tables[result['table_id']])
result['sql']['from'] = [result['table_id']]
sql = self.sql_dict_to_str(
result=result, table=self.db.tables[result['table_id']])
try:
result['sql'] = self.post_process_multi_turn(
history_sql=history_sql,
result=result,
table=self.db.tables[result['table_id']])
except Exception:
result['sql'] = history_sql
sql = self.sql_dict_to_str(result=result, tables=self.db.tables)
# add sqlite
if self.db.is_use_sqlite:
try:
cursor = self.db.connection_obj.cursor().execute(sql.query)
names = [{
'name':
description[0],
'label':
self.db.tables[result['table_id']]['headerid2name'].get(
description[0], description[0])
} for description in cursor.description]
cells = []
header_ids, header_names = [], []
for description in cursor.description:
header_ids.append(self.db.tables[result['table_id']]
['headerid2name'].get(
description[0], description[0]))
header_names.append(description[0])
rows = []
for res in cursor.fetchall():
row = {}
for name, cell in zip(names, res):
row[name['name']] = cell
cells.append(row)
tabledata = {'headers': names, 'cells': cells}
rows.append(list(res))
tabledata = {
'header_id': header_ids,
'header_name': header_names,
'rows': rows
}
except Exception:
tabledata = {'headers': [], 'cells': []}
tabledata = {'header_id': [], 'header_name': [], 'rows': []}
else:
tabledata = {'headers': [], 'cells': []}
tabledata = {'header_id': [], 'header_name': [], 'rows': []}
output = {
OutputKeys.SQL_STRING: sql.string,
OutputKeys.SQL_QUERY: sql.query,
OutputKeys.HISTORY: result['sql'],
OutputKeys.QUERT_RESULT: json.dumps(tabledata, ensure_ascii=False),
OutputKeys.QUERT_RESULT: tabledata,
}
return output
@@ -13,7 +13,7 @@ class Database:
tokenizer,
table_file_path,
syn_dict_file_path,
is_use_sqlite=False):
is_use_sqlite=True):
self.tokenizer = tokenizer
self.is_use_sqlite = is_use_sqlite
if self.is_use_sqlite:
@@ -293,6 +293,7 @@ class SchemaLinker:
nlu_t,
tables,
col_syn_dict,
table_id=None,
history_sql=None):
"""
get linking between question and schema column
@@ -300,6 +301,9 @@ class SchemaLinker:
typeinfos = []
numbers = re.findall(r'[-]?\d*\.\d+|[-]?\d+|\d+', nlu)
if table_id is not None and table_id in tables:
tables = {table_id: tables[table_id]}
# search schema link in every table
search_result_list = []
for tablename in tables:
@@ -411,26 +415,25 @@ class SchemaLinker:
# get the match score of each table
match_score = self.get_table_match_score(nlu_t, schema_link)
# cal table_score
if history_sql is not None and 'from' in history_sql:
table_score = int(table['table_id'] == history_sql['from'][0])
else:
table_score = 0
search_result = {
'table_id':
table['table_id'],
'question_knowledge':
final_question,
'header_knowledge':
final_header,
'schema_link':
schema_link,
'match_score':
match_score,
'table_score':
int(table['table_id'] == history_sql['from'][0])
if history_sql is not None else 0
'table_id': table['table_id'],
'question_knowledge': final_question,
'header_knowledge': final_header,
'schema_link': schema_link,
'match_score': match_score,
'table_score': table_score
}
search_result_list.append(search_result)
search_result_list = sorted(
search_result_list,
key=lambda x: (x['match_score'], x['table_score']),
reverse=True)[0:4]
reverse=True)[0:1]
return search_result_list
@@ -95,6 +95,7 @@ class TableQuestionAnsweringPreprocessor(Preprocessor):
# tokenize question
question = data['question']
table_id = data.get('table_id', None)
history_sql = data.get('history_sql', None)
nlu = question.lower()
nlu_t = self.tokenizer.tokenize(nlu)
@@ -106,6 +107,7 @@ class TableQuestionAnsweringPreprocessor(Preprocessor):
nlu_t=nlu_t,
tables=self.db.tables,
col_syn_dict=self.db.syn_dict,
table_id=table_id,
history_sql=history_sql)
# collect data
@@ -43,7 +43,7 @@ def tableqa_tracking_and_print_results_with_history(
print('sql text:', output_dict[OutputKeys.SQL_STRING])
print('sql query:', output_dict[OutputKeys.SQL_QUERY])
print('query result:', output_dict[OutputKeys.QUERT_RESULT])
print('json dumps', json.dumps(output_dict))
print('json dumps', json.dumps(output_dict, ensure_ascii=False))
print()
historical_queries = output_dict[OutputKeys.HISTORY]
@@ -66,10 +66,42 @@ def tableqa_tracking_and_print_results_without_history(
print('sql text:', output_dict[OutputKeys.SQL_STRING])
print('sql query:', output_dict[OutputKeys.SQL_QUERY])
print('query result:', output_dict[OutputKeys.QUERT_RESULT])
print('json dumps', json.dumps(output_dict))
print('json dumps', json.dumps(output_dict, ensure_ascii=False))
print()
def tableqa_tracking_and_print_results_with_tableid(
pipelines: List[TableQuestionAnsweringPipeline]):
test_case = {
'utterance': [
['有哪些风险类型?', 'fund'],
['风险类型有多少种?', 'reservoir'],
['珠江流域的小(2)型水库的库容总量是多少?', 'reservoir'],
['那平均值是多少?', 'reservoir'],
['那水库的名称呢?', 'reservoir'],
['换成中型的呢?', 'reservoir'],
['枣庄营业厅的电话', 'business'],
['那地址呢?', 'business'],
['枣庄营业厅的电话和地址', 'business'],
],
}
for p in pipelines:
historical_queries = None
for question, table_id in test_case['utterance']:
output_dict = p({
'question': question,
'table_id': table_id,
'history_sql': historical_queries
})
print('question', question)
print('sql text:', output_dict[OutputKeys.SQL_STRING])
print('sql query:', output_dict[OutputKeys.SQL_QUERY])
print('query result:', output_dict[OutputKeys.QUERT_RESULT])
print('json dumps', json.dumps(output_dict, ensure_ascii=False))
print()
historical_queries = output_dict[OutputKeys.HISTORY]
class TableQuestionAnswering(unittest.TestCase):
def setUp(self) -> None:
@@ -93,15 +125,27 @@ class TableQuestionAnswering(unittest.TestCase):
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
self.tokenizer = BertTokenizer(
os.path.join(model.model_dir, ModelFile.VOCAB_FILE))
db = Database(
tokenizer=self.tokenizer,
table_file_path=[
os.path.join(model.model_dir, 'databases', fname)
for fname in os.listdir(
os.path.join(model.model_dir, 'databases'))
],
syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'),
is_use_sqlite=False)
preprocessor = TableQuestionAnsweringPreprocessor(
model_dir=model.model_dir)
model_dir=model.model_dir, db=db)
pipelines = [
pipeline(
Tasks.table_question_answering,
model=model,
preprocessor=preprocessor)
preprocessor=preprocessor,
db=db)
]
tableqa_tracking_and_print_results_with_history(pipelines)
tableqa_tracking_and_print_results_with_tableid(pipelines)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_task(self):
@@ -132,7 +176,6 @@ class TableQuestionAnswering(unittest.TestCase):
db=db)
]
tableqa_tracking_and_print_results_without_history(pipelines)
tableqa_tracking_and_print_results_with_history(pipelines)
if __name__ == '__main__':