
[to #42322933] update tableqa params

1. Pass table_id in as an input parameter
2. Unify the structure of result and table
3. Enable is_use_sqlite by default

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10492027
master
caorongyu.cry  yingda.chen · 3 years ago
commit 9edfd7e50c
5 changed files with 96 additions and 45 deletions
  1. modelscope/pipelines/nlp/table_question_answering_pipeline.py (+27, -24)
  2. modelscope/preprocessors/star3/fields/database.py (+1, -1)
  3. modelscope/preprocessors/star3/fields/schema_link.py (+17, -14)
  4. modelscope/preprocessors/star3/table_question_answering_preprocessor.py (+2, -0)
  5. tests/pipelines/test_table_question_answering.py (+49, -6)
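
Taken together, these changes let a caller pass an optional table_id next to question and history_sql, and OutputKeys.QUERT_RESULT is now returned as a dict instead of a JSON string. A minimal usage sketch following the call pattern of the updated tests below; tableqa_pipeline is assumed to be built as in test_run_with_model_from_modelhub, and the question/table id come from its test cases:

    from modelscope.outputs import OutputKeys

    # tableqa_pipeline: assumed to be a table-question-answering pipeline
    # constructed as in test_run_with_model_from_modelhub further below
    output = tableqa_pipeline({
        'question': '枣庄营业厅的电话和地址',
        'table_id': 'business',   # restrict schema linking to this table
        'history_sql': None,      # or the previous turn's HISTORY output
    })
    print(output[OutputKeys.SQL_QUERY])       # generated SQL
    print(output[OutputKeys.QUERT_RESULT])    # dict with header_id / header_name / rows
    history_sql = output[OutputKeys.HISTORY]  # feed back as 'history_sql' next turn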

modelscope/pipelines/nlp/table_question_answering_pipeline.py (+27, -24)

@@ -72,6 +72,7 @@ class TableQuestionAnsweringPipeline(Pipeline):
         action = self.action_ops[result['action']]
         headers = table['header_name']
         current_sql = result['sql']
+        current_sql['from'] = [table['table_id']]
 
         if history_sql is None:
             return current_sql
@@ -216,10 +217,11 @@ class TableQuestionAnsweringPipeline(Pipeline):
         else:
             return current_sql
 
-    def sql_dict_to_str(self, result, table):
+    def sql_dict_to_str(self, result, tables):
         """
         convert sql struct to string
         """
+        table = tables[result['sql']['from'][0]]
         header_names = table['header_name'] + ['空列']
         header_ids = table['header_id'] + ['null']
         sql = result['sql']
@@ -279,42 +281,43 @@
         """
         result = inputs['result']
         history_sql = inputs['history_sql']
-        result['sql'] = self.post_process_multi_turn(
-            history_sql=history_sql,
-            result=result,
-            table=self.db.tables[result['table_id']])
-        result['sql']['from'] = [result['table_id']]
-        sql = self.sql_dict_to_str(
-            result=result, table=self.db.tables[result['table_id']])
+        try:
+            result['sql'] = self.post_process_multi_turn(
+                history_sql=history_sql,
+                result=result,
+                table=self.db.tables[result['table_id']])
+        except Exception:
+            result['sql'] = history_sql
+        sql = self.sql_dict_to_str(result=result, tables=self.db.tables)
 
         # add sqlite
         if self.db.is_use_sqlite:
             try:
                 cursor = self.db.connection_obj.cursor().execute(sql.query)
-                names = [{
-                    'name':
-                    description[0],
-                    'label':
-                    self.db.tables[result['table_id']]['headerid2name'].get(
-                        description[0], description[0])
-                } for description in cursor.description]
-                cells = []
+                header_ids, header_names = [], []
+                for description in cursor.description:
+                    header_ids.append(self.db.tables[result['table_id']]
+                                      ['headerid2name'].get(
+                                          description[0], description[0]))
+                    header_names.append(description[0])
+                rows = []
                 for res in cursor.fetchall():
-                    row = {}
-                    for name, cell in zip(names, res):
-                        row[name['name']] = cell
-                    cells.append(row)
-                tabledata = {'headers': names, 'cells': cells}
+                    rows.append(list(res))
+                tabledata = {
+                    'header_id': header_ids,
+                    'header_name': header_names,
+                    'rows': rows
+                }
             except Exception:
-                tabledata = {'headers': [], 'cells': []}
+                tabledata = {'header_id': [], 'header_name': [], 'rows': []}
         else:
-            tabledata = {'headers': [], 'cells': []}
+            tabledata = {'header_id': [], 'header_name': [], 'rows': []}
 
         output = {
             OutputKeys.SQL_STRING: sql.string,
             OutputKeys.SQL_QUERY: sql.query,
             OutputKeys.HISTORY: result['sql'],
-            OutputKeys.QUERT_RESULT: json.dumps(tabledata, ensure_ascii=False),
+            OutputKeys.QUERT_RESULT: tabledata,
         }
 
         return output

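The query result under OutputKeys.QUERT_RESULT now uses the same field names as the table schema (header_id, header_name, rows) rather than the old {'headers': ..., 'cells': ...} layout. A shape sketch with made-up values:

    # illustrative values only; both header lists are aligned column-wise
    # with every entry of 'rows'
    tabledata = {
        'header_id': ['col_a', 'col_b'],
        'header_name': ['col_a_name', 'col_b_name'],
        'rows': [
            ['value_1a', 'value_1b'],
            ['value_2a', 'value_2b'],
        ],
    }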

modelscope/preprocessors/star3/fields/database.py (+1, -1)

@@ -13,7 +13,7 @@ class Database:
                  tokenizer,
                  table_file_path,
                  syn_dict_file_path,
-                 is_use_sqlite=False):
+                 is_use_sqlite=True):
         self.tokenizer = tokenizer
         self.is_use_sqlite = is_use_sqlite
         if self.is_use_sqlite:

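Because is_use_sqlite now defaults to True, Database sets up the SQLite connection that the pipeline uses to execute the generated SQL unless the caller opts out explicitly, as the updated modelhub test below does. A sketch of the opt-out, with placeholder paths and a tokenizer assumed to be built as in that test:

    from modelscope.preprocessors.star3.fields.database import Database

    # paths are placeholders; in the test they point into the downloaded model dir
    db = Database(
        tokenizer=tokenizer,  # a BertTokenizer over the model's vocab file
        table_file_path=['/path/to/databases/table_file'],
        syn_dict_file_path='/path/to/synonym.txt',
        is_use_sqlite=False)  # opt out of SQLite-backed query execution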

modelscope/preprocessors/star3/fields/schema_link.py (+17, -14)

@@ -293,6 +293,7 @@ class SchemaLinker:
                      nlu_t,
                      tables,
                      col_syn_dict,
+                     table_id=None,
                      history_sql=None):
         """
         get linking between question and schema column
@@ -300,6 +301,9 @@
         typeinfos = []
         numbers = re.findall(r'[-]?\d*\.\d+|[-]?\d+|\d+', nlu)
 
+        if table_id is not None and table_id in tables:
+            tables = {table_id: tables[table_id]}
+
         # search schema link in every table
         search_result_list = []
         for tablename in tables:
@@ -411,26 +415,25 @@
             # get the match score of each table
             match_score = self.get_table_match_score(nlu_t, schema_link)
 
+            # cal table_score
+            if history_sql is not None and 'from' in history_sql:
+                table_score = int(table['table_id'] == history_sql['from'][0])
+            else:
+                table_score = 0
+
             search_result = {
-                'table_id':
-                table['table_id'],
-                'question_knowledge':
-                final_question,
-                'header_knowledge':
-                final_header,
-                'schema_link':
-                schema_link,
-                'match_score':
-                match_score,
-                'table_score':
-                int(table['table_id'] == history_sql['from'][0])
-                if history_sql is not None else 0
+                'table_id': table['table_id'],
+                'question_knowledge': final_question,
+                'header_knowledge': final_header,
+                'schema_link': schema_link,
+                'match_score': match_score,
+                'table_score': table_score
             }
             search_result_list.append(search_result)
 
         search_result_list = sorted(
             search_result_list,
             key=lambda x: (x['match_score'], x['table_score']),
-            reverse=True)[0:4]
+            reverse=True)[0:1]
 
         return search_result_list

modelscope/preprocessors/star3/table_question_answering_preprocessor.py (+2, -0)

@@ -95,6 +95,7 @@ class TableQuestionAnsweringPreprocessor(Preprocessor):
 
         # tokenize question
         question = data['question']
+        table_id = data.get('table_id', None)
        history_sql = data.get('history_sql', None)
         nlu = question.lower()
         nlu_t = self.tokenizer.tokenize(nlu)
@@ -106,6 +107,7 @@ class TableQuestionAnsweringPreprocessor(Preprocessor):
             nlu_t=nlu_t,
             tables=self.db.tables,
             col_syn_dict=self.db.syn_dict,
+            table_id=table_id,
             history_sql=history_sql)
 
         # collect data


tests/pipelines/test_table_question_answering.py (+49, -6)

@@ -43,7 +43,7 @@ def tableqa_tracking_and_print_results_with_history(
             print('sql text:', output_dict[OutputKeys.SQL_STRING])
             print('sql query:', output_dict[OutputKeys.SQL_QUERY])
             print('query result:', output_dict[OutputKeys.QUERT_RESULT])
-            print('json dumps', json.dumps(output_dict))
+            print('json dumps', json.dumps(output_dict, ensure_ascii=False))
             print()
             historical_queries = output_dict[OutputKeys.HISTORY]

@@ -66,10 +66,42 @@ def tableqa_tracking_and_print_results_without_history(
             print('sql text:', output_dict[OutputKeys.SQL_STRING])
             print('sql query:', output_dict[OutputKeys.SQL_QUERY])
             print('query result:', output_dict[OutputKeys.QUERT_RESULT])
-            print('json dumps', json.dumps(output_dict))
+            print('json dumps', json.dumps(output_dict, ensure_ascii=False))
             print()
 
 
+def tableqa_tracking_and_print_results_with_tableid(
+        pipelines: List[TableQuestionAnsweringPipeline]):
+    test_case = {
+        'utterance': [
+            ['有哪些风险类型?', 'fund'],
+            ['风险类型有多少种?', 'reservoir'],
+            ['珠江流域的小(2)型水库的库容总量是多少?', 'reservoir'],
+            ['那平均值是多少?', 'reservoir'],
+            ['那水库的名称呢?', 'reservoir'],
+            ['换成中型的呢?', 'reservoir'],
+            ['枣庄营业厅的电话', 'business'],
+            ['那地址呢?', 'business'],
+            ['枣庄营业厅的电话和地址', 'business'],
+        ],
+    }
+    for p in pipelines:
+        historical_queries = None
+        for question, table_id in test_case['utterance']:
+            output_dict = p({
+                'question': question,
+                'table_id': table_id,
+                'history_sql': historical_queries
+            })
+            print('question', question)
+            print('sql text:', output_dict[OutputKeys.SQL_STRING])
+            print('sql query:', output_dict[OutputKeys.SQL_QUERY])
+            print('query result:', output_dict[OutputKeys.QUERT_RESULT])
+            print('json dumps', json.dumps(output_dict, ensure_ascii=False))
+            print()
+            historical_queries = output_dict[OutputKeys.HISTORY]
+
+
 class TableQuestionAnswering(unittest.TestCase):
 
     def setUp(self) -> None:
@@ -93,15 +125,27 @@ class TableQuestionAnswering(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_model_from_modelhub(self):
         model = Model.from_pretrained(self.model_id)
+        self.tokenizer = BertTokenizer(
+            os.path.join(model.model_dir, ModelFile.VOCAB_FILE))
+        db = Database(
+            tokenizer=self.tokenizer,
+            table_file_path=[
+                os.path.join(model.model_dir, 'databases', fname)
+                for fname in os.listdir(
+                    os.path.join(model.model_dir, 'databases'))
+            ],
+            syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'),
+            is_use_sqlite=False)
         preprocessor = TableQuestionAnsweringPreprocessor(
-            model_dir=model.model_dir)
+            model_dir=model.model_dir, db=db)
         pipelines = [
             pipeline(
                 Tasks.table_question_answering,
                 model=model,
-                preprocessor=preprocessor)
+                preprocessor=preprocessor,
+                db=db)
         ]
         tableqa_tracking_and_print_results_with_history(pipelines)
+        tableqa_tracking_and_print_results_with_tableid(pipelines)
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_model_from_task(self):
@@ -132,7 +176,6 @@ class TableQuestionAnswering(unittest.TestCase):
                 db=db)
         ]
         tableqa_tracking_and_print_results_without_history(pipelines)
-        tableqa_tracking_and_print_results_with_history(pipelines)
 
 
 if __name__ == '__main__':

