修改output的结构,直接返回可转化成json format的结构
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10415403
master
| @@ -691,11 +691,11 @@ class TableQuestionAnswering(Model): | |||||
| sels.append(l_hs[ib] - 1) | sels.append(l_hs[ib] - 1) | ||||
| aggs.append(sql['agg'][ia]) | aggs.append(sql['agg'][ia]) | ||||
| continue | continue | ||||
| sels.append(sel) | |||||
| sels.append(int(sel)) | |||||
| if sql['agg'][ia] == -1: | if sql['agg'][ia] == -1: | ||||
| aggs.append(0) | aggs.append(0) | ||||
| else: | else: | ||||
| aggs.append(sql['agg'][ia]) | |||||
| aggs.append(int(sql['agg'][ia])) | |||||
| if len(sels) == 0: | if len(sels) == 0: | ||||
| sels.append(l_hs[ib] - 1) | sels.append(l_hs[ib] - 1) | ||||
| aggs.append(0) | aggs.append(0) | ||||
| @@ -712,7 +712,7 @@ class TableQuestionAnswering(Model): | |||||
| for i in range(wl): | for i in range(wl): | ||||
| if wc_os[i] == -1: | if wc_os[i] == -1: | ||||
| continue | continue | ||||
| conds.append([wc_os[i], wo_os[i], pr_wvi_str[ib][i]]) | |||||
| conds.append([int(wc_os[i]), int(wo_os[i]), pr_wvi_str[ib][i]]) | |||||
| if len(conds) == 0: | if len(conds) == 0: | ||||
| conds.append([l_hs[ib] - 1, 2, 'Nulll']) | conds.append([l_hs[ib] - 1, 2, 'Nulll']) | ||||
| sql['conds'] = conds | sql['conds'] = conds | ||||
| @@ -36,6 +36,8 @@ class OutputKeys(object): | |||||
| UUID = 'uuid' | UUID = 'uuid' | ||||
| WORD = 'word' | WORD = 'word' | ||||
| KWS_LIST = 'kws_list' | KWS_LIST = 'kws_list' | ||||
| SQL_STRING = 'sql_string' | |||||
| SQL_QUERY = 'sql_query' | |||||
| HISTORY = 'history' | HISTORY = 'history' | ||||
| QUERT_RESULT = 'query_result' | QUERT_RESULT = 'query_result' | ||||
| TIMESTAMPS = 'timestamps' | TIMESTAMPS = 'timestamps' | ||||
| @@ -583,7 +585,10 @@ TASK_OUTPUTS = { | |||||
| # "sql": "SELECT shop.Name FROM shop." | # "sql": "SELECT shop.Name FROM shop." | ||||
| # "sql_history": {sel: 0, agg: 0, conds: [[0, 0, 'val']]} | # "sql_history": {sel: 0, agg: 0, conds: [[0, 0, 'val']]} | ||||
| # } | # } | ||||
| Tasks.table_question_answering: [OutputKeys.OUTPUT, OutputKeys.HISTORY], | |||||
| Tasks.table_question_answering: [ | |||||
| OutputKeys.SQL_STRING, OutputKeys.SQL_QUERY, OutputKeys.HISTORY, | |||||
| OutputKeys.QUERT_RESULT | |||||
| ], | |||||
| # ============ audio tasks =================== | # ============ audio tasks =================== | ||||
| # asr result for single sample | # asr result for single sample | ||||
| @@ -311,7 +311,8 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||||
| tabledata = {'headers': [], 'cells': []} | tabledata = {'headers': [], 'cells': []} | ||||
| output = { | output = { | ||||
| OutputKeys.OUTPUT: sql, | |||||
| OutputKeys.SQL_STRING: sql.string, | |||||
| OutputKeys.SQL_QUERY: sql.query, | |||||
| OutputKeys.HISTORY: result['sql'], | OutputKeys.HISTORY: result['sql'], | ||||
| OutputKeys.QUERT_RESULT: json.dumps(tabledata, ensure_ascii=False), | OutputKeys.QUERT_RESULT: json.dumps(tabledata, ensure_ascii=False), | ||||
| } | } | ||||
| @@ -3,10 +3,12 @@ import os | |||||
| import unittest | import unittest | ||||
| from typing import List | from typing import List | ||||
| import json | |||||
| from transformers import BertTokenizer | from transformers import BertTokenizer | ||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| from modelscope.models import Model | from modelscope.models import Model | ||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines import pipeline | from modelscope.pipelines import pipeline | ||||
| from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline | from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline | ||||
| from modelscope.preprocessors import TableQuestionAnsweringPreprocessor | from modelscope.preprocessors import TableQuestionAnsweringPreprocessor | ||||
| @@ -38,11 +40,12 @@ def tableqa_tracking_and_print_results_with_history( | |||||
| 'history_sql': historical_queries | 'history_sql': historical_queries | ||||
| }) | }) | ||||
| print('question', question) | print('question', question) | ||||
| print('sql text:', output_dict['output'].string) | |||||
| print('sql query:', output_dict['output'].query) | |||||
| print('query result:', output_dict['query_result']) | |||||
| print('sql text:', output_dict[OutputKeys.SQL_STRING]) | |||||
| print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | |||||
| print('query result:', output_dict[OutputKeys.QUERT_RESULT]) | |||||
| print('json dumps', json.dumps(output_dict)) | |||||
| print() | print() | ||||
| historical_queries = output_dict['history'] | |||||
| historical_queries = output_dict[OutputKeys.HISTORY] | |||||
| def tableqa_tracking_and_print_results_without_history( | def tableqa_tracking_and_print_results_without_history( | ||||
| @@ -60,9 +63,10 @@ def tableqa_tracking_and_print_results_without_history( | |||||
| for question in test_case['utterance']: | for question in test_case['utterance']: | ||||
| output_dict = p({'question': question}) | output_dict = p({'question': question}) | ||||
| print('question', question) | print('question', question) | ||||
| print('sql text:', output_dict['output'].string) | |||||
| print('sql query:', output_dict['output'].query) | |||||
| print('query result:', output_dict['query_result']) | |||||
| print('sql text:', output_dict[OutputKeys.SQL_STRING]) | |||||
| print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | |||||
| print('query result:', output_dict[OutputKeys.QUERT_RESULT]) | |||||
| print('json dumps', json.dumps(output_dict)) | |||||
| print() | print() | ||||