diff --git a/modelscope/models/nlp/table_question_answering.py b/modelscope/models/nlp/table_question_answering.py index c6a03ef3..c2134df2 100644 --- a/modelscope/models/nlp/table_question_answering.py +++ b/modelscope/models/nlp/table_question_answering.py @@ -691,11 +691,11 @@ class TableQuestionAnswering(Model): sels.append(l_hs[ib] - 1) aggs.append(sql['agg'][ia]) continue - sels.append(sel) + sels.append(int(sel)) if sql['agg'][ia] == -1: aggs.append(0) else: - aggs.append(sql['agg'][ia]) + aggs.append(int(sql['agg'][ia])) if len(sels) == 0: sels.append(l_hs[ib] - 1) aggs.append(0) @@ -712,7 +712,7 @@ class TableQuestionAnswering(Model): for i in range(wl): if wc_os[i] == -1: continue - conds.append([wc_os[i], wo_os[i], pr_wvi_str[ib][i]]) + conds.append([int(wc_os[i]), int(wo_os[i]), pr_wvi_str[ib][i]]) if len(conds) == 0: conds.append([l_hs[ib] - 1, 2, 'Nulll']) sql['conds'] = conds diff --git a/modelscope/outputs.py b/modelscope/outputs.py index 3001c03c..c08779b4 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -36,6 +36,8 @@ class OutputKeys(object): UUID = 'uuid' WORD = 'word' KWS_LIST = 'kws_list' + SQL_STRING = 'sql_string' + SQL_QUERY = 'sql_query' HISTORY = 'history' QUERT_RESULT = 'query_result' TIMESTAMPS = 'timestamps' @@ -583,7 +585,10 @@ TASK_OUTPUTS = { # "sql": "SELECT shop.Name FROM shop." # "sql_history": {sel: 0, agg: 0, conds: [[0, 0, 'val']]} # } - Tasks.table_question_answering: [OutputKeys.OUTPUT, OutputKeys.HISTORY], + Tasks.table_question_answering: [ + OutputKeys.SQL_STRING, OutputKeys.SQL_QUERY, OutputKeys.HISTORY, + OutputKeys.QUERT_RESULT + ], # ============ audio tasks =================== # asr result for single sample diff --git a/modelscope/pipelines/nlp/table_question_answering_pipeline.py b/modelscope/pipelines/nlp/table_question_answering_pipeline.py index e1b2b07b..ca17c9b1 100644 --- a/modelscope/pipelines/nlp/table_question_answering_pipeline.py +++ b/modelscope/pipelines/nlp/table_question_answering_pipeline.py @@ -311,7 +311,8 @@ class TableQuestionAnsweringPipeline(Pipeline): tabledata = {'headers': [], 'cells': []} output = { - OutputKeys.OUTPUT: sql, + OutputKeys.SQL_STRING: sql.string, + OutputKeys.SQL_QUERY: sql.query, OutputKeys.HISTORY: result['sql'], OutputKeys.QUERT_RESULT: json.dumps(tabledata, ensure_ascii=False), } diff --git a/tests/pipelines/test_table_question_answering.py b/tests/pipelines/test_table_question_answering.py index 68e0564f..3d943e51 100644 --- a/tests/pipelines/test_table_question_answering.py +++ b/tests/pipelines/test_table_question_answering.py @@ -3,10 +3,12 @@ import os import unittest from typing import List +import json from transformers import BertTokenizer from modelscope.hub.snapshot_download import snapshot_download from modelscope.models import Model +from modelscope.outputs import OutputKeys from modelscope.pipelines import pipeline from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline from modelscope.preprocessors import TableQuestionAnsweringPreprocessor @@ -38,11 +40,12 @@ def tableqa_tracking_and_print_results_with_history( 'history_sql': historical_queries }) print('question', question) - print('sql text:', output_dict['output'].string) - print('sql query:', output_dict['output'].query) - print('query result:', output_dict['query_result']) + print('sql text:', output_dict[OutputKeys.SQL_STRING]) + print('sql query:', output_dict[OutputKeys.SQL_QUERY]) + print('query result:', output_dict[OutputKeys.QUERT_RESULT]) + print('json dumps', json.dumps(output_dict)) print() - historical_queries = output_dict['history'] + historical_queries = output_dict[OutputKeys.HISTORY] def tableqa_tracking_and_print_results_without_history( @@ -60,9 +63,10 @@ def tableqa_tracking_and_print_results_without_history( for question in test_case['utterance']: output_dict = p({'question': question}) print('question', question) - print('sql text:', output_dict['output'].string) - print('sql query:', output_dict['output'].query) - print('query result:', output_dict['query_result']) + print('sql text:', output_dict[OutputKeys.SQL_STRING]) + print('sql query:', output_dict[OutputKeys.SQL_QUERY]) + print('query result:', output_dict[OutputKeys.QUERT_RESULT]) + print('json dumps', json.dumps(output_dict)) print()