
[to #42322933] revise model problem and remove history sql for demo

Compared with the tableqa on master, this change makes the following fixes:
1. Fixed issues in schema linking.
2. Added two input modes: with history sql and without history sql (see the usage sketch after the review link below).
3. Added sqlite execution logic, so SQL execution results can be returned.
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10365114
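
[Review note] A minimal usage sketch of the two input modes this change enables; the model id and output keys are taken from the tests and diff below, everything else is illustrative:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    p = pipeline(Tasks.table_question_answering,
                 'damo/nlp_convai_text2sql_pretrain_cn')

    # Mode 1: single-turn; 'history_sql' may now be omitted entirely,
    # since the preprocessor falls back to data.get('history_sql', None).
    out = p({'question': '有哪些风险类型?'})

    # Mode 2: multi-turn; feed the previous turn's sql dict back in.
    out2 = p({'question': '风险类型有多少种?', 'history_sql': out['history']})

    # When the Database is built with is_use_sqlite=True, the executed
    # result is returned as JSON under OutputKeys.QUERT_RESULT.
    print(out2['query_result'])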
Branch: master
Author: caorongyu.cry, Committer: yingda.chen (3 years ago)
Commit: 7145990054
8 changed files with 206 additions and 61 deletions
1. +2  -1   modelscope/models/nlp/table_question_answering.py
2. +1  -0   modelscope/outputs.py
3. +50 -11  modelscope/pipelines/nlp/table_question_answering_pipeline.py
4. +49 -4   modelscope/preprocessors/star3/fields/database.py
5. +23 -10  modelscope/preprocessors/star3/fields/schema_link.py
6. +3  -2   modelscope/preprocessors/star3/table_question_answering_preprocessor.py
7. +1  -16  modelscope/utils/nlp/nlp_utils.py
8. +77 -17  tests/pipelines/test_table_question_answering.py

+2 -1  modelscope/models/nlp/table_question_answering.py

@@ -3,9 +3,11 @@
 import os
 from typing import Dict

+import json
 import numpy
 import torch
 import torch.nn.functional as F
+import tqdm
 from transformers import BertTokenizer

 from modelscope.metainfo import Models
@@ -82,7 +84,6 @@ class TableQuestionAnswering(Model):

                 if ntok.startswith('##'):
                     ntok = ntok.replace('##', '')
-
                 tok = nlu1[idx:idx + 1].lower()
                 if ntok == tok:
                     conv_dict[i] = [idx, idx + 1]
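
[Review note] For context, this hunk sits in the loop that aligns BERT wordpiece tokens ('##'-prefixed continuations) back to character offsets in the question. A standalone sketch of that alignment idea; the helper name and details are illustrative, not the model's exact code:

    from transformers import BertTokenizer

    def align_tokens(tokens, question):
        # Map each wordpiece token to its [start, end) character span.
        conv_dict, idx = {}, 0
        for i, tok in enumerate(tokens):
            piece = tok.replace('##', '')  # strip the continuation marker
            # scan past characters the tokenizer skipped (spaces etc.);
            # no [UNK] handling, this is only a sketch
            while (idx < len(question)
                   and question[idx:idx + len(piece)].lower() != piece):
                idx += 1
            conv_dict[i] = [idx, idx + len(piece)]
            idx += len(piece)
        return conv_dict

    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    q = '有哪些风险类型?'
    print(align_tokens(tokenizer.tokenize(q), q))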


+1 -0  modelscope/outputs.py

@@ -37,6 +37,7 @@ class OutputKeys(object):
     WORD = 'word'
     KWS_LIST = 'kws_list'
     HISTORY = 'history'
+    QUERT_RESULT = 'query_result'
     TIMESTAMPS = 'timestamps'
     SHOT_NUM = 'shot_num'
     SCENE_NUM = 'scene_num'


+50 -11  modelscope/pipelines/nlp/table_question_answering_pipeline.py

@@ -2,6 +2,8 @@
 import os
 from typing import Any, Dict, Union

+import json
+import torch
 from transformers import BertTokenizer

 from modelscope.metainfo import Pipelines
@@ -230,14 +232,16 @@ class TableQuestionAnsweringPipeline(Pipeline):
                 str_sel_list.append(header_name)
                 sql_sel_list.append(header_id)
             else:
-                str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '( '
-                                    + header_name + ' )')
-                sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '( '
-                                    + header_id + ' )')
+                str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '('
+                                    + header_name + ')')
+                sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '('
+                                    + header_id + ')')

         str_cond_list, sql_cond_list = [], []
         for cond in sql['conds']:
             header_name = header_names[cond[0]]
+            if header_name == '空列':
+                continue
             header_id = '`%s`.`%s`' % (table['table_id'], header_ids[cond[0]])
             op = self.cond_ops[cond[1]]
             value = cond[2]
@@ -248,12 +252,17 @@ class TableQuestionAnsweringPipeline(Pipeline):

         cond = ' ' + self.cond_conn_ops[sql['cond_conn_op']] + ' '

-        final_str = 'SELECT %s FROM %s WHERE %s' % (', '.join(str_sel_list),
-                                                    table['table_name'],
-                                                    cond.join(str_cond_list))
-        final_sql = 'SELECT %s FROM `%s` WHERE %s' % (', '.join(sql_sel_list),
-                                                      table['table_id'],
-                                                      cond.join(sql_cond_list))
+        if len(str_cond_list) != 0:
+            final_str = 'SELECT %s FROM %s WHERE %s' % (', '.join(
+                str_sel_list), table['table_name'], cond.join(str_cond_list))
+            final_sql = 'SELECT %s FROM `%s` WHERE %s' % (', '.join(
+                sql_sel_list), table['table_id'], cond.join(sql_cond_list))
+        else:
+            final_str = 'SELECT %s FROM %s' % (', '.join(str_sel_list),
+                                               table['table_name'])
+            final_sql = 'SELECT %s FROM `%s`' % (', '.join(sql_sel_list),
+                                                 table['table_id'])

         sql = SQLQuery(
             string=final_str, query=final_sql, sql_result=result['sql'])

@@ -274,9 +283,39 @@ class TableQuestionAnsweringPipeline(Pipeline):
                 history_sql=history_sql,
                 result=result,
                 table=self.db.tables[result['table_id']])
+        result['sql']['from'] = [result['table_id']]
         sql = self.sql_dict_to_str(
             result=result, table=self.db.tables[result['table_id']])
-        output = {OutputKeys.OUTPUT: sql, OutputKeys.HISTORY: result['sql']}
+
+        # add sqlite
+        if self.db.is_use_sqlite:
+            try:
+                cursor = self.db.connection_obj.cursor().execute(sql.query)
+                names = [{
+                    'name':
+                    description[0],
+                    'label':
+                    self.db.tables[result['table_id']]['headerid2name'].get(
+                        description[0], description[0])
+                } for description in cursor.description]
+                cells = []
+                for res in cursor.fetchall():
+                    row = {}
+                    for name, cell in zip(names, res):
+                        row[name['name']] = cell
+                    cells.append(row)
+                tabledata = {'headers': names, 'cells': cells}
+            except Exception:
+                tabledata = {'headers': [], 'cells': []}
+        else:
+            tabledata = {'headers': [], 'cells': []}
+
+        output = {
+            OutputKeys.OUTPUT: sql,
+            OutputKeys.HISTORY: result['sql'],
+            OutputKeys.QUERT_RESULT: json.dumps(tabledata, ensure_ascii=False),
+        }

         return output

     def _collate_fn(self, data):
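
[Review note] The sqlite branch above maps cursor.description back to human-readable labels through the new headerid2name table field. A self-contained sketch of that packaging step; the table, columns and rows are made-up fixtures:

    import json
    import sqlite3

    conn = sqlite3.connect(':memory:')
    conn.execute('CREATE TABLE t1 (col1 TEXT, col2 INT)')
    conn.execute("INSERT INTO t1 VALUES ('枣庄营业厅', 10)")
    headerid2name = {'col1': '名称', 'col2': '数量'}

    cursor = conn.cursor().execute('SELECT col1, col2 FROM t1')
    # each cursor.description entry is a 7-tuple; [0] is the column name
    names = [{'name': d[0], 'label': headerid2name.get(d[0], d[0])}
             for d in cursor.description]
    cells = [{n['name']: cell for n, cell in zip(names, row)}
             for row in cursor.fetchall()]
    tabledata = {'headers': names, 'cells': cells}
    print(json.dumps(tabledata, ensure_ascii=False))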


+49 -4  modelscope/preprocessors/star3/fields/database.py

@@ -1,4 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import sqlite3
+
 import json
 import tqdm

@@ -7,18 +9,38 @@ from modelscope.preprocessors.star3.fields.struct import Trie

 class Database:

-    def __init__(self, tokenizer, table_file_path, syn_dict_file_path):
+    def __init__(self,
+                 tokenizer,
+                 table_file_path,
+                 syn_dict_file_path,
+                 is_use_sqlite=False):
         self.tokenizer = tokenizer
+        self.is_use_sqlite = is_use_sqlite
+        if self.is_use_sqlite:
+            self.connection_obj = sqlite3.connect(':memory:')
+        self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'}
         self.tables = self.init_tables(table_file_path=table_file_path)
         self.syn_dict = self.init_syn_dict(
             syn_dict_file_path=syn_dict_file_path)

+    def __del__(self):
+        if self.is_use_sqlite:
+            self.connection_obj.close()
+
     def init_tables(self, table_file_path):
         tables = {}
         lines = []
-        with open(table_file_path, 'r') as fo:
-            for line in fo:
-                lines.append(line)
+        if type(table_file_path) == str:
+            with open(table_file_path, 'r') as fo:
+                for line in fo:
+                    lines.append(line)
+        elif type(table_file_path) == list:
+            for path in table_file_path:
+                with open(path, 'r') as fo:
+                    for line in fo:
+                        lines.append(line)
+        else:
+            raise ValueError()

         for line in tqdm.tqdm(lines, desc='Load Tables'):
             table = json.loads(line.strip())
@@ -34,6 +56,9 @@ class Database:
                 headers_tokens.append(empty_column)
             table['tablelen'] = table_header_length
             table['header_tok'] = headers_tokens
+            table['headerid2name'] = {}
+            for hid, hname in zip(table['header_id'], table['header_name']):
+                table['headerid2name'][hid] = hname

             table['header_types'].append('null')
             table['header_units'] = [
@@ -51,6 +76,26 @@ class Database:
                         trie_set[ii].insert(word, word)

             table['value_trie'] = trie_set
+
+            # create sqlite
+            if self.is_use_sqlite:
+                cursor_obj = self.connection_obj.cursor()
+                cursor_obj.execute('DROP TABLE IF EXISTS %s' %
+                                   (table['table_id']))
+                header_string = ', '.join([
+                    '%s %s' %
+                    (name, self.type_dict[htype]) for name, htype in zip(
+                        table['header_id'], table['header_types'])
+                ])
+                create_table_string = 'CREATE TABLE %s (%s);' % (
+                    table['table_id'], header_string)
+                cursor_obj.execute(create_table_string)
+                for row in table['rows']:
+                    value_string = ', '.join(['"%s"' % (val) for val in row])
+                    insert_row_string = 'INSERT INTO %s VALUES(%s)' % (
+                        table['table_id'], value_string)
+                    cursor_obj.execute(insert_row_string)

             tables[table['table_id']] = table

         return tables
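
[Review note] The INSERT above interpolates cell values with '"%s"', which breaks on values that themselves contain double quotes. A hedged alternative sketch using sqlite3 parameter binding for the values (identifiers still come from the trusted table file, as in the diff):

    import sqlite3

    conn = sqlite3.connect(':memory:')
    conn.execute('CREATE TABLE demo_table (h1 TEXT, h2 INT)')
    # values that would break the naive '"%s"' quoting
    rows = [['6" pipe', 3], ["it's", 7]]
    conn.executemany('INSERT INTO demo_table VALUES (?, ?)', rows)
    print(conn.execute('SELECT * FROM demo_table').fetchall())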


+23 -10  modelscope/preprocessors/star3/fields/schema_link.py

@@ -287,7 +287,13 @@ class SchemaLinker:

         return match_len / (len(nlu_t) + 0.1)

-    def get_entity_linking(self, tokenizer, nlu, nlu_t, tables, col_syn_dict):
+    def get_entity_linking(self,
+                           tokenizer,
+                           nlu,
+                           nlu_t,
+                           tables,
+                           col_syn_dict,
+                           history_sql=None):
         """
         get linking between question and schema column
         """
@@ -305,8 +311,7 @@
             typeinfos = []
             for ii, column in enumerate(table['header_name']):
                 column = column.lower()
-                column_new = re.sub('(.*?)', '', column)
-                column_new = re.sub('(.*?)', '', column_new)
+                column_new = column
                 cphrase, cscore = self.get_match_phrase(
                     nlu.lower(), column_new)
                 if cscore > 0.3 and cphrase.strip() != '':
@@ -330,7 +335,6 @@
                 for cell in ans.keys():
                     vphrase = cell
                     vscore = 1.0
-                    # print("trie_set find:", cell, ans[cell])
                     phrase_tok = tokenizer.tokenize(vphrase)
                     if len(phrase_tok) == 0 or len(vphrase) < 2:
                         continue
@@ -408,16 +412,25 @@
             match_score = self.get_table_match_score(nlu_t, schema_link)

             search_result = {
-                'table_id': table['table_id'],
-                'question_knowledge': final_question,
-                'header_knowledge': final_header,
-                'schema_link': schema_link,
-                'match_score': match_score
+                'table_id':
+                table['table_id'],
+                'question_knowledge':
+                final_question,
+                'header_knowledge':
+                final_header,
+                'schema_link':
+                schema_link,
+                'match_score':
+                match_score,
+                'table_score':
+                int(table['table_id'] == history_sql['from'][0])
+                if history_sql is not None else 0
             }
             search_result_list.append(search_result)

         search_result_list = sorted(
-            search_result_list, key=lambda x: x['match_score'],
+            search_result_list,
+            key=lambda x: (x['match_score'], x['table_score']),
             reverse=True)[0:4]

         return search_result_list
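
[Review note] The new sort key ranks candidates lexicographically by (match_score, table_score), so the table referenced by history_sql only breaks ties between equally matched tables. A toy illustration with made-up scores:

    candidates = [
        {'table_id': 'a', 'match_score': 0.8, 'table_score': 0},
        {'table_id': 'b', 'match_score': 0.8, 'table_score': 1},
        {'table_id': 'c', 'match_score': 0.9, 'table_score': 0},
    ]
    top = sorted(candidates,
                 key=lambda x: (x['match_score'], x['table_score']),
                 reverse=True)[0:4]
    print([c['table_id'] for c in top])  # ['c', 'b', 'a']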

+3 -2  modelscope/preprocessors/star3/table_question_answering_preprocessor.py

@@ -95,7 +95,7 @@ class TableQuestionAnsweringPreprocessor(Preprocessor):

         # tokenize question
         question = data['question']
-        history_sql = data['history_sql']
+        history_sql = data.get('history_sql', None)
         nlu = question.lower()
         nlu_t = self.tokenizer.tokenize(nlu)

@@ -105,7 +105,8 @@ class TableQuestionAnsweringPreprocessor(Preprocessor):
             nlu=nlu,
             nlu_t=nlu_t,
             tables=self.db.tables,
-            col_syn_dict=self.db.syn_dict)
+            col_syn_dict=self.db.syn_dict,
+            history_sql=history_sql)

         # collect data
         datas = self.construct_data(


+1 -16  modelscope/utils/nlp/nlp_utils.py

@@ -2,8 +2,7 @@ from typing import List

 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.nlp import (ConversationalTextToSqlPipeline,
-                                      DialogStateTrackingPipeline,
-                                      TableQuestionAnsweringPipeline)
+                                      DialogStateTrackingPipeline)


 def text2sql_tracking_and_print_results(
@@ -42,17 +41,3 @@ def tracking_and_print_dialog_states(
         print(json.dumps(result))

         history_states.extend([result[OutputKeys.OUTPUT], {}])
-
-
-def tableqa_tracking_and_print_results(
-        test_case, pipelines: List[TableQuestionAnsweringPipeline]):
-    for pipeline in pipelines:
-        historical_queries = None
-        for question in test_case['utterance']:
-            output_dict = pipeline({
-                'question': question,
-                'history_sql': historical_queries
-            })
-            print('output_dict', output_dict['output'].string,
-                  output_dict['output'].query)
-            historical_queries = output_dict['history']

+77 -17  tests/pipelines/test_table_question_answering.py

@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
 import unittest
+from typing import List

 from transformers import BertTokenizer

@@ -11,10 +12,60 @@ from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline
 from modelscope.preprocessors import TableQuestionAnsweringPreprocessor
 from modelscope.preprocessors.star3.fields.database import Database
 from modelscope.utils.constant import ModelFile, Tasks
-from modelscope.utils.nlp.nlp_utils import tableqa_tracking_and_print_results
 from modelscope.utils.test_utils import test_level


+def tableqa_tracking_and_print_results_with_history(
+        pipelines: List[TableQuestionAnsweringPipeline]):
+    test_case = {
+        'utterance': [
+            '有哪些风险类型?',
+            '风险类型有多少种?',
+            '珠江流域的小(2)型水库的库容总量是多少?',
+            '那平均值是多少?',
+            '那水库的名称呢?',
+            '换成中型的呢?',
+            '枣庄营业厅的电话',
+            '那地址呢?',
+            '枣庄营业厅的电话和地址',
+        ]
+    }
+    for p in pipelines:
+        historical_queries = None
+        for question in test_case['utterance']:
+            output_dict = p({
+                'question': question,
+                'history_sql': historical_queries
+            })
+            print('question', question)
+            print('sql text:', output_dict['output'].string)
+            print('sql query:', output_dict['output'].query)
+            print('query result:', output_dict['query_result'])
+            print()
+            historical_queries = output_dict['history']
+
+
+def tableqa_tracking_and_print_results_without_history(
+        pipelines: List[TableQuestionAnsweringPipeline]):
+    test_case = {
+        'utterance': [
+            '有哪些风险类型?',
+            '风险类型有多少种?',
+            '珠江流域的小(2)型水库的库容总量是多少?',
+            '枣庄营业厅的电话',
+            '枣庄营业厅的电话和地址',
+        ]
+    }
+    for p in pipelines:
+        for question in test_case['utterance']:
+            output_dict = p({'question': question})
+            print('question', question)
+            print('sql text:', output_dict['output'].string)
+            print('sql query:', output_dict['output'].query)
+            print('query result:', output_dict['query_result'])
+            print()
+
+
 class TableQuestionAnswering(unittest.TestCase):

     def setUp(self) -> None:
@@ -22,20 +73,18 @@ class TableQuestionAnswering(unittest.TestCase):
         self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn'

     model_id = 'damo/nlp_convai_text2sql_pretrain_cn'
-    test_case = {
-        'utterance':
-        ['长江流域的小(2)型水库的库容总量是多少?', '那平均值是多少?', '那水库的名称呢?', '换成中型的呢?']
-    }

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_by_direct_model_download(self):
         cache_path = snapshot_download(self.model_id)
         preprocessor = TableQuestionAnsweringPreprocessor(model_dir=cache_path)
         pipelines = [
-            TableQuestionAnsweringPipeline(
-                model=cache_path, preprocessor=preprocessor)
+            pipeline(
+                Tasks.table_question_answering,
+                model=cache_path,
+                preprocessor=preprocessor)
         ]
-        tableqa_tracking_and_print_results(self.test_case, pipelines)
+        tableqa_tracking_and_print_results_with_history(pipelines)

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_model_from_modelhub(self):
@@ -43,15 +92,17 @@ class TableQuestionAnswering(unittest.TestCase):
         preprocessor = TableQuestionAnsweringPreprocessor(
             model_dir=model.model_dir)
         pipelines = [
-            TableQuestionAnsweringPipeline(
-                model=model, preprocessor=preprocessor)
+            pipeline(
+                Tasks.table_question_answering,
+                model=model,
+                preprocessor=preprocessor)
         ]
-        tableqa_tracking_and_print_results(self.test_case, pipelines)
+        tableqa_tracking_and_print_results_with_history(pipelines)

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_model_from_task(self):
         pipelines = [pipeline(Tasks.table_question_answering, self.model_id)]
-        tableqa_tracking_and_print_results(self.test_case, pipelines)
+        tableqa_tracking_and_print_results_with_history(pipelines)

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_model_from_modelhub_with_other_classes(self):
@@ -60,15 +111,24 @@ class TableQuestionAnswering(unittest.TestCase):
             os.path.join(model.model_dir, ModelFile.VOCAB_FILE))
         db = Database(
             tokenizer=self.tokenizer,
-            table_file_path=os.path.join(model.model_dir, 'table.json'),
-            syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'))
+            table_file_path=[
+                os.path.join(model.model_dir, 'databases', fname)
+                for fname in os.listdir(
+                    os.path.join(model.model_dir, 'databases'))
+            ],
+            syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'),
+            is_use_sqlite=True)
         preprocessor = TableQuestionAnsweringPreprocessor(
             model_dir=model.model_dir, db=db)
         pipelines = [
-            TableQuestionAnsweringPipeline(
-                model=model, preprocessor=preprocessor, db=db)
+            pipeline(
+                Tasks.table_question_answering,
+                model=model,
+                preprocessor=preprocessor,
+                db=db)
         ]
-        tableqa_tracking_and_print_results(self.test_case, pipelines)
+        tableqa_tracking_and_print_results_without_history(pipelines)
+        tableqa_tracking_and_print_results_with_history(pipelines)


 if __name__ == '__main__':
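
[Review note] To exercise both demo paths locally, the standard unittest entry point should work from the repo root (assuming the test level read by test_utils.test_level() is set high enough for the skip guards):

    python -m unittest tests.pipelines.test_table_question_answering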

