diff --git a/modelscope/pipelines/nlp/table_question_answering_pipeline.py b/modelscope/pipelines/nlp/table_question_answering_pipeline.py
index ca17c9b1..08501953 100644
--- a/modelscope/pipelines/nlp/table_question_answering_pipeline.py
+++ b/modelscope/pipelines/nlp/table_question_answering_pipeline.py
@@ -72,6 +72,7 @@ class TableQuestionAnsweringPipeline(Pipeline):
         action = self.action_ops[result['action']]
         headers = table['header_name']
         current_sql = result['sql']
+        current_sql['from'] = [table['table_id']]
 
         if history_sql is None:
             return current_sql
@@ -216,10 +217,11 @@ class TableQuestionAnsweringPipeline(Pipeline):
         else:
             return current_sql
 
-    def sql_dict_to_str(self, result, table):
+    def sql_dict_to_str(self, result, tables):
         """
             convert sql struct to string
         """
+        table = tables[result['sql']['from'][0]]
         header_names = table['header_name'] + ['空列']
         header_ids = table['header_id'] + ['null']
         sql = result['sql']
@@ -279,42 +281,43 @@ class TableQuestionAnsweringPipeline(Pipeline):
         """
         result = inputs['result']
         history_sql = inputs['history_sql']
-        result['sql'] = self.post_process_multi_turn(
-            history_sql=history_sql,
-            result=result,
-            table=self.db.tables[result['table_id']])
-        result['sql']['from'] = [result['table_id']]
-        sql = self.sql_dict_to_str(
-            result=result, table=self.db.tables[result['table_id']])
+        try:
+            result['sql'] = self.post_process_multi_turn(
+                history_sql=history_sql,
+                result=result,
+                table=self.db.tables[result['table_id']])
+        except Exception:
+            result['sql'] = history_sql
+        sql = self.sql_dict_to_str(result=result, tables=self.db.tables)
 
         # add sqlite
         if self.db.is_use_sqlite:
             try:
                 cursor = self.db.connection_obj.cursor().execute(sql.query)
-                names = [{
-                    'name':
-                    description[0],
-                    'label':
-                    self.db.tables[result['table_id']]['headerid2name'].get(
-                        description[0], description[0])
-                } for description in cursor.description]
-                cells = []
+                header_ids, header_names = [], []
+                for description in cursor.description:
+                    header_ids.append(description[0])
+                    header_names.append(
+                        self.db.tables[result['table_id']]
+                        ['headerid2name'].get(description[0], description[0]))
+                rows = []
                 for res in cursor.fetchall():
-                    row = {}
-                    for name, cell in zip(names, res):
-                        row[name['name']] = cell
-                    cells.append(row)
-                tabledata = {'headers': names, 'cells': cells}
+                    rows.append(list(res))
+                tabledata = {
+                    'header_id': header_ids,
+                    'header_name': header_names,
+                    'rows': rows
+                }
             except Exception:
-                tabledata = {'headers': [], 'cells': []}
+                tabledata = {'header_id': [], 'header_name': [], 'rows': []}
         else:
-            tabledata = {'headers': [], 'cells': []}
+            tabledata = {'header_id': [], 'header_name': [], 'rows': []}
 
         output = {
             OutputKeys.SQL_STRING: sql.string,
             OutputKeys.SQL_QUERY: sql.query,
             OutputKeys.HISTORY: result['sql'],
-            OutputKeys.QUERT_RESULT: json.dumps(tabledata, ensure_ascii=False),
+            OutputKeys.QUERT_RESULT: tabledata,
         }
 
         return output
diff --git a/modelscope/preprocessors/star3/fields/database.py b/modelscope/preprocessors/star3/fields/database.py
index 3d3a1f8d..5debfe2c 100644
--- a/modelscope/preprocessors/star3/fields/database.py
+++ b/modelscope/preprocessors/star3/fields/database.py
@@ -13,7 +13,7 @@ class Database:
                  tokenizer,
                  table_file_path,
                  syn_dict_file_path,
-                 is_use_sqlite=False):
+                 is_use_sqlite=True):
         self.tokenizer = tokenizer
         self.is_use_sqlite = is_use_sqlite
         if self.is_use_sqlite:
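Note on the postprocess hunk above: `OutputKeys.QUERT_RESULT` now holds a plain dict (`header_id`/`header_name`/`rows`) instead of a JSON string, so callers no longer need `json.loads`. A minimal consumption sketch, assuming `output_dict` is the result of a single pipeline call as in the tests further down:

```python
from modelscope.outputs import OutputKeys

result = output_dict[OutputKeys.QUERT_RESULT]  # a dict now, no json.loads
for row in result['rows']:
    # pair each cell with its human-readable column name
    print(dict(zip(result['header_name'], row)))
```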
diff --git a/modelscope/preprocessors/star3/fields/schema_link.py b/modelscope/preprocessors/star3/fields/schema_link.py
index 7f483a1f..220a71d8 100644
--- a/modelscope/preprocessors/star3/fields/schema_link.py
+++ b/modelscope/preprocessors/star3/fields/schema_link.py
@@ -293,6 +293,7 @@ class SchemaLinker:
                                  nlu_t,
                                  tables,
                                  col_syn_dict,
+                                 table_id=None,
                                  history_sql=None):
         """
             get linking between question and schema column
@@ -300,6 +301,9 @@ class SchemaLinker:
         typeinfos = []
         numbers = re.findall(r'[-]?\d*\.\d+|[-]?\d+|\d+', nlu)
 
+        if table_id is not None and table_id in tables:
+            tables = {table_id: tables[table_id]}
+
         # search schema link in every table
         search_result_list = []
         for tablename in tables:
@@ -411,26 +415,25 @@ class SchemaLinker:
             # get the match score of each table
             match_score = self.get_table_match_score(nlu_t, schema_link)
 
+            # calculate table_score
+            if history_sql is not None and 'from' in history_sql:
+                table_score = int(table['table_id'] == history_sql['from'][0])
+            else:
+                table_score = 0
+
             search_result = {
-                'table_id':
-                table['table_id'],
-                'question_knowledge':
-                final_question,
-                'header_knowledge':
-                final_header,
-                'schema_link':
-                schema_link,
-                'match_score':
-                match_score,
-                'table_score':
-                int(table['table_id'] == history_sql['from'][0])
-                if history_sql is not None else 0
+                'table_id': table['table_id'],
+                'question_knowledge': final_question,
+                'header_knowledge': final_header,
+                'schema_link': schema_link,
+                'match_score': match_score,
+                'table_score': table_score
             }
             search_result_list.append(search_result)
 
         search_result_list = sorted(
             search_result_list,
             key=lambda x: (x['match_score'], x['table_score']),
-            reverse=True)[0:4]
+            reverse=True)[0:1]
 
         return search_result_list
diff --git a/modelscope/preprocessors/star3/table_question_answering_preprocessor.py b/modelscope/preprocessors/star3/table_question_answering_preprocessor.py
index f98aa6d0..ed2911f6 100644
--- a/modelscope/preprocessors/star3/table_question_answering_preprocessor.py
+++ b/modelscope/preprocessors/star3/table_question_answering_preprocessor.py
@@ -95,6 +95,7 @@ class TableQuestionAnsweringPreprocessor(Preprocessor):
 
         # tokenize question
         question = data['question']
+        table_id = data.get('table_id', None)
         history_sql = data.get('history_sql', None)
         nlu = question.lower()
         nlu_t = self.tokenizer.tokenize(nlu)
@@ -106,6 +107,7 @@ class TableQuestionAnsweringPreprocessor(Preprocessor):
             nlu_t=nlu_t,
             tables=self.db.tables,
             col_syn_dict=self.db.syn_dict,
+            table_id=table_id,
             history_sql=history_sql)
 
         # collect data
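The `table_id` routing added to `get_question_schema_link` reduces to the filter below; this is a standalone restatement of the hunk above for illustration (the helper name `restrict_tables` is ours, not part of the diff):

```python
from typing import Dict, Optional


def restrict_tables(tables: Dict, table_id: Optional[str] = None) -> Dict:
    """Pin schema linking to a single table when the caller passes a known
    table_id; otherwise fall back to searching every table."""
    if table_id is not None and table_id in tables:
        return {table_id: tables[table_id]}
    return tables
```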
diff --git a/tests/pipelines/test_table_question_answering.py b/tests/pipelines/test_table_question_answering.py
index 3d943e51..571ca795 100644
--- a/tests/pipelines/test_table_question_answering.py
+++ b/tests/pipelines/test_table_question_answering.py
@@ -43,7 +43,7 @@ def tableqa_tracking_and_print_results_with_history(
             print('sql text:', output_dict[OutputKeys.SQL_STRING])
             print('sql query:', output_dict[OutputKeys.SQL_QUERY])
             print('query result:', output_dict[OutputKeys.QUERT_RESULT])
-            print('json dumps', json.dumps(output_dict))
+            print('json dumps', json.dumps(output_dict, ensure_ascii=False))
             print()
 
             historical_queries = output_dict[OutputKeys.HISTORY]
@@ -66,10 +66,42 @@ def tableqa_tracking_and_print_results_without_history(
             print('sql text:', output_dict[OutputKeys.SQL_STRING])
             print('sql query:', output_dict[OutputKeys.SQL_QUERY])
             print('query result:', output_dict[OutputKeys.QUERT_RESULT])
-            print('json dumps', json.dumps(output_dict))
+            print('json dumps', json.dumps(output_dict, ensure_ascii=False))
             print()
 
 
+def tableqa_tracking_and_print_results_with_tableid(
+        pipelines: List[TableQuestionAnsweringPipeline]):
+    test_case = {
+        'utterance': [
+            ['有哪些风险类型?', 'fund'],
+            ['风险类型有多少种?', 'reservoir'],
+            ['珠江流域的小(2)型水库的库容总量是多少?', 'reservoir'],
+            ['那平均值是多少?', 'reservoir'],
+            ['那水库的名称呢?', 'reservoir'],
+            ['换成中型的呢?', 'reservoir'],
+            ['枣庄营业厅的电话', 'business'],
+            ['那地址呢?', 'business'],
+            ['枣庄营业厅的电话和地址', 'business'],
+        ],
+    }
+    for p in pipelines:
+        historical_queries = None
+        for question, table_id in test_case['utterance']:
+            output_dict = p({
+                'question': question,
+                'table_id': table_id,
+                'history_sql': historical_queries
+            })
+            print('question', question)
+            print('sql text:', output_dict[OutputKeys.SQL_STRING])
+            print('sql query:', output_dict[OutputKeys.SQL_QUERY])
+            print('query result:', output_dict[OutputKeys.QUERT_RESULT])
+            print('json dumps', json.dumps(output_dict, ensure_ascii=False))
+            print()
+            historical_queries = output_dict[OutputKeys.HISTORY]
+
+
 class TableQuestionAnswering(unittest.TestCase):
 
     def setUp(self) -> None:
@@ -93,15 +125,27 @@ class TableQuestionAnswering(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_model_from_modelhub(self):
         model = Model.from_pretrained(self.model_id)
+        self.tokenizer = BertTokenizer(
+            os.path.join(model.model_dir, ModelFile.VOCAB_FILE))
+        db = Database(
+            tokenizer=self.tokenizer,
+            table_file_path=[
+                os.path.join(model.model_dir, 'databases', fname)
+                for fname in os.listdir(
+                    os.path.join(model.model_dir, 'databases'))
+            ],
+            syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'),
+            is_use_sqlite=False)
         preprocessor = TableQuestionAnsweringPreprocessor(
-            model_dir=model.model_dir)
+            model_dir=model.model_dir, db=db)
         pipelines = [
             pipeline(
                 Tasks.table_question_answering,
                 model=model,
-                preprocessor=preprocessor)
+                preprocessor=preprocessor,
+                db=db)
         ]
-        tableqa_tracking_and_print_results_with_history(pipelines)
+        tableqa_tracking_and_print_results_with_tableid(pipelines)
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_model_from_task(self):
@@ -132,7 +176,6 @@ class TableQuestionAnswering(unittest.TestCase):
                 db=db)
         ]
         tableqa_tracking_and_print_results_without_history(pipelines)
-        tableqa_tracking_and_print_results_with_history(pipelines)
 
 
 if __name__ == '__main__':
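For completeness, a hedged end-to-end sketch of the new multi-turn flow with `table_id` routing; the model id is a placeholder assumption, and the inputs mirror the new test helper above:

```python
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# placeholder model id -- substitute a real TableQA model from the hub
pipe = pipeline(Tasks.table_question_answering, model='<tableqa-model-id>')

history = None
for question, table_id in [('枣庄营业厅的电话', 'business'),
                           ('那地址呢?', 'business')]:
    output = pipe({
        'question': question,
        'table_id': table_id,   # pins schema linking to one table
        'history_sql': history  # None on the first turn
    })
    print(output[OutputKeys.SQL_QUERY])
    history = output[OutputKeys.HISTORY]  # merged SQL dict; carries 'from'
```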