Browse Source

[to #42322933] change tableqa output

修改output的结构,直接返回可转化成json format的结构
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10415403
master
caorongyu.cry yingda.chen 3 years ago
parent
commit
202fcdf298
4 changed files with 22 additions and 12 deletions
  1. +3
    -3
      modelscope/models/nlp/table_question_answering.py
  2. +6
    -1
      modelscope/outputs.py
  3. +2
    -1
      modelscope/pipelines/nlp/table_question_answering_pipeline.py
  4. +11
    -7
      tests/pipelines/test_table_question_answering.py

+ 3
- 3
modelscope/models/nlp/table_question_answering.py View File

@@ -691,11 +691,11 @@ class TableQuestionAnswering(Model):
sels.append(l_hs[ib] - 1) sels.append(l_hs[ib] - 1)
aggs.append(sql['agg'][ia]) aggs.append(sql['agg'][ia])
continue continue
sels.append(sel)
sels.append(int(sel))
if sql['agg'][ia] == -1: if sql['agg'][ia] == -1:
aggs.append(0) aggs.append(0)
else: else:
aggs.append(sql['agg'][ia])
aggs.append(int(sql['agg'][ia]))
if len(sels) == 0: if len(sels) == 0:
sels.append(l_hs[ib] - 1) sels.append(l_hs[ib] - 1)
aggs.append(0) aggs.append(0)
@@ -712,7 +712,7 @@ class TableQuestionAnswering(Model):
for i in range(wl): for i in range(wl):
if wc_os[i] == -1: if wc_os[i] == -1:
continue continue
conds.append([wc_os[i], wo_os[i], pr_wvi_str[ib][i]])
conds.append([int(wc_os[i]), int(wo_os[i]), pr_wvi_str[ib][i]])
if len(conds) == 0: if len(conds) == 0:
conds.append([l_hs[ib] - 1, 2, 'Nulll']) conds.append([l_hs[ib] - 1, 2, 'Nulll'])
sql['conds'] = conds sql['conds'] = conds


+ 6
- 1
modelscope/outputs.py View File

@@ -36,6 +36,8 @@ class OutputKeys(object):
UUID = 'uuid' UUID = 'uuid'
WORD = 'word' WORD = 'word'
KWS_LIST = 'kws_list' KWS_LIST = 'kws_list'
SQL_STRING = 'sql_string'
SQL_QUERY = 'sql_query'
HISTORY = 'history' HISTORY = 'history'
QUERT_RESULT = 'query_result' QUERT_RESULT = 'query_result'
TIMESTAMPS = 'timestamps' TIMESTAMPS = 'timestamps'
@@ -583,7 +585,10 @@ TASK_OUTPUTS = {
# "sql": "SELECT shop.Name FROM shop." # "sql": "SELECT shop.Name FROM shop."
# "sql_history": {sel: 0, agg: 0, conds: [[0, 0, 'val']]} # "sql_history": {sel: 0, agg: 0, conds: [[0, 0, 'val']]}
# } # }
Tasks.table_question_answering: [OutputKeys.OUTPUT, OutputKeys.HISTORY],
Tasks.table_question_answering: [
OutputKeys.SQL_STRING, OutputKeys.SQL_QUERY, OutputKeys.HISTORY,
OutputKeys.QUERT_RESULT
],


# ============ audio tasks =================== # ============ audio tasks ===================
# asr result for single sample # asr result for single sample


+ 2
- 1
modelscope/pipelines/nlp/table_question_answering_pipeline.py View File

@@ -311,7 +311,8 @@ class TableQuestionAnsweringPipeline(Pipeline):
tabledata = {'headers': [], 'cells': []} tabledata = {'headers': [], 'cells': []}


output = { output = {
OutputKeys.OUTPUT: sql,
OutputKeys.SQL_STRING: sql.string,
OutputKeys.SQL_QUERY: sql.query,
OutputKeys.HISTORY: result['sql'], OutputKeys.HISTORY: result['sql'],
OutputKeys.QUERT_RESULT: json.dumps(tabledata, ensure_ascii=False), OutputKeys.QUERT_RESULT: json.dumps(tabledata, ensure_ascii=False),
} }


+ 11
- 7
tests/pipelines/test_table_question_answering.py View File

@@ -3,10 +3,12 @@ import os
import unittest import unittest
from typing import List from typing import List


import json
from transformers import BertTokenizer from transformers import BertTokenizer


from modelscope.hub.snapshot_download import snapshot_download from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline
from modelscope.preprocessors import TableQuestionAnsweringPreprocessor from modelscope.preprocessors import TableQuestionAnsweringPreprocessor
@@ -38,11 +40,12 @@ def tableqa_tracking_and_print_results_with_history(
'history_sql': historical_queries 'history_sql': historical_queries
}) })
print('question', question) print('question', question)
print('sql text:', output_dict['output'].string)
print('sql query:', output_dict['output'].query)
print('query result:', output_dict['query_result'])
print('sql text:', output_dict[OutputKeys.SQL_STRING])
print('sql query:', output_dict[OutputKeys.SQL_QUERY])
print('query result:', output_dict[OutputKeys.QUERT_RESULT])
print('json dumps', json.dumps(output_dict))
print() print()
historical_queries = output_dict['history']
historical_queries = output_dict[OutputKeys.HISTORY]




def tableqa_tracking_and_print_results_without_history( def tableqa_tracking_and_print_results_without_history(
@@ -60,9 +63,10 @@ def tableqa_tracking_and_print_results_without_history(
for question in test_case['utterance']: for question in test_case['utterance']:
output_dict = p({'question': question}) output_dict = p({'question': question})
print('question', question) print('question', question)
print('sql text:', output_dict['output'].string)
print('sql query:', output_dict['output'].query)
print('query result:', output_dict['query_result'])
print('sql text:', output_dict[OutputKeys.SQL_STRING])
print('sql query:', output_dict[OutputKeys.SQL_QUERY])
print('query result:', output_dict[OutputKeys.QUERT_RESULT])
print('json dumps', json.dumps(output_dict))
print() print()






Loading…
Cancel
Save