[to #42322933] revise model problem and remove history sql for demo

Compared with the tableqa on master, this commit makes the following fixes:
1. Fixed issues in schema linking.
2. Added two input modes: with history sql and without history sql.
3. Added sqlite execution logic, so SQL execution results can be returned.
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10365114
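
In practice the two input modes differ only in whether 'history_sql' is passed; a minimal usage sketch based on the test code in this commit (model id and output keys are taken from the diffs below):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    pipe = pipeline(Tasks.table_question_answering,
                    'damo/nlp_convai_text2sql_pretrain_cn')

    # without history sql: the 'history_sql' key may now be omitted entirely
    out = pipe({'question': '有哪些风险类型?'})

    # with history sql: feed the previous turn's parsed sql back in
    out = pipe({'question': '那平均值是多少?', 'history_sql': out['history']})
    print(out['output'].query)   # executable SQL string
    print(out['query_result'])   # JSON-encoded sqlite result, empty if sqlite is off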
master
caorongyu.cry (author), yingda.chen · 3 years ago
commit 7145990054
8 changed files with 206 additions and 61 deletions
  1. +2  -1   modelscope/models/nlp/table_question_answering.py
  2. +1  -0   modelscope/outputs.py
  3. +50 -11  modelscope/pipelines/nlp/table_question_answering_pipeline.py
  4. +49 -4   modelscope/preprocessors/star3/fields/database.py
  5. +23 -10  modelscope/preprocessors/star3/fields/schema_link.py
  6. +3  -2   modelscope/preprocessors/star3/table_question_answering_preprocessor.py
  7. +1  -16  modelscope/utils/nlp/nlp_utils.py
  8. +77 -17  tests/pipelines/test_table_question_answering.py

+2 -1  modelscope/models/nlp/table_question_answering.py

@@ -3,9 +3,11 @@
 import os
 from typing import Dict
 
+import json
 import numpy
 import torch
 import torch.nn.functional as F
+import tqdm
 from transformers import BertTokenizer
 
 from modelscope.metainfo import Models
@@ -82,7 +84,6 @@ class TableQuestionAnswering(Model):
                 if ntok.startswith('##'):
                     ntok = ntok.replace('##', '')
-
                 tok = nlu1[idx:idx + 1].lower()
                 if ntok == tok:
                     conv_dict[i] = [idx, idx + 1]
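
For context, the loop above maps BERT wordpiece tokens back to character offsets in the question; a minimal sketch of the '##' convention it relies on (illustrative strings only):

    # subword pieces after the first carry a '##' prefix; stripping it lets
    # each piece be compared against the raw question characters one by one
    pieces = ['un', '##aff', '##able']            # typical WordPiece output
    print([p.replace('##', '') for p in pieces])  # ['un', 'aff', 'able']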


+1 -0  modelscope/outputs.py

@@ -37,6 +37,7 @@ class OutputKeys(object):
     WORD = 'word'
     KWS_LIST = 'kws_list'
     HISTORY = 'history'
+    QUERT_RESULT = 'query_result'
     TIMESTAMPS = 'timestamps'
     SHOT_NUM = 'shot_num'
     SCENE_NUM = 'scene_num'
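
Callers index the pipeline output with the new constant (its value is the string 'query_result'); a short sketch, assuming `output` is the dict returned by one pipeline call:

    import json

    from modelscope.outputs import OutputKeys

    tabledata = json.loads(output[OutputKeys.QUERT_RESULT])
    print(tabledata['headers'], tabledata['cells'])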


+50 -11  modelscope/pipelines/nlp/table_question_answering_pipeline.py

@@ -2,6 +2,8 @@
 import os
 from typing import Any, Dict, Union
 
+import json
+import torch
 from transformers import BertTokenizer
 
 from modelscope.metainfo import Pipelines
@@ -230,14 +232,16 @@ class TableQuestionAnsweringPipeline(Pipeline):
                 str_sel_list.append(header_name)
                 sql_sel_list.append(header_id)
             else:
-                str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '( '
-                                    + header_name + ' )')
-                sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '( '
-                                    + header_id + ' )')
+                str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '('
+                                    + header_name + ')')
+                sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '('
+                                    + header_id + ')')
 
         str_cond_list, sql_cond_list = [], []
         for cond in sql['conds']:
             header_name = header_names[cond[0]]
+            if header_name == '空列':
+                continue
             header_id = '`%s`.`%s`' % (table['table_id'], header_ids[cond[0]])
             op = self.cond_ops[cond[1]]
             value = cond[2]
@@ -248,12 +252,17 @@ class TableQuestionAnsweringPipeline(Pipeline):
 
         cond = ' ' + self.cond_conn_ops[sql['cond_conn_op']] + ' '
 
-        final_str = 'SELECT %s FROM %s WHERE %s' % (', '.join(str_sel_list),
-                                                    table['table_name'],
-                                                    cond.join(str_cond_list))
-        final_sql = 'SELECT %s FROM `%s` WHERE %s' % (', '.join(sql_sel_list),
-                                                      table['table_id'],
-                                                      cond.join(sql_cond_list))
+        if len(str_cond_list) != 0:
+            final_str = 'SELECT %s FROM %s WHERE %s' % (', '.join(
+                str_sel_list), table['table_name'], cond.join(str_cond_list))
+            final_sql = 'SELECT %s FROM `%s` WHERE %s' % (', '.join(
+                sql_sel_list), table['table_id'], cond.join(sql_cond_list))
+        else:
+            final_str = 'SELECT %s FROM %s' % (', '.join(str_sel_list),
+                                               table['table_name'])
+            final_sql = 'SELECT %s FROM `%s`' % (', '.join(sql_sel_list),
+                                                 table['table_id'])
 
         sql = SQLQuery(
             string=final_str, query=final_sql, sql_result=result['sql'])
@@ -274,9 +283,39 @@ class TableQuestionAnsweringPipeline(Pipeline):
             history_sql=history_sql,
             result=result,
             table=self.db.tables[result['table_id']])
+        result['sql']['from'] = [result['table_id']]
         sql = self.sql_dict_to_str(
             result=result, table=self.db.tables[result['table_id']])
-        output = {OutputKeys.OUTPUT: sql, OutputKeys.HISTORY: result['sql']}
+
+        # add sqlite
+        if self.db.is_use_sqlite:
+            try:
+                cursor = self.db.connection_obj.cursor().execute(sql.query)
+                names = [{
+                    'name':
+                    description[0],
+                    'label':
+                    self.db.tables[result['table_id']]['headerid2name'].get(
+                        description[0], description[0])
+                } for description in cursor.description]
+                cells = []
+                for res in cursor.fetchall():
+                    row = {}
+                    for name, cell in zip(names, res):
+                        row[name['name']] = cell
+                    cells.append(row)
+                tabledata = {'headers': names, 'cells': cells}
+            except Exception:
+                tabledata = {'headers': [], 'cells': []}
+        else:
+            tabledata = {'headers': [], 'cells': []}
+
+        output = {
+            OutputKeys.OUTPUT: sql,
+            OutputKeys.HISTORY: result['sql'],
+            OutputKeys.QUERT_RESULT: json.dumps(tabledata, ensure_ascii=False),
+        }
 
         return output
 
     def _collate_fn(self, data):
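
The result table is serialized with ensure_ascii=False so Chinese header labels survive as-is; a sketch of the payload shape built above (hypothetical `output_dict` from one pipeline call):

    import json

    tabledata = json.loads(output_dict['query_result'])
    # 'headers' pairs each sqlite column id with a human-readable label,
    # 'cells' holds one {column_id: value} dict per returned row
    for header in tabledata['headers']:
        print(header['name'], '->', header['label'])
    for row in tabledata['cells']:
        print(row)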


+49 -4  modelscope/preprocessors/star3/fields/database.py

@@ -1,4 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import sqlite3
+
 import json
 import tqdm
@@ -7,18 +9,38 @@ from modelscope.preprocessors.star3.fields.struct import Trie
 
 class Database:
 
-    def __init__(self, tokenizer, table_file_path, syn_dict_file_path):
+    def __init__(self,
+                 tokenizer,
+                 table_file_path,
+                 syn_dict_file_path,
+                 is_use_sqlite=False):
         self.tokenizer = tokenizer
+        self.is_use_sqlite = is_use_sqlite
+        if self.is_use_sqlite:
+            self.connection_obj = sqlite3.connect(':memory:')
+        self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'}
         self.tables = self.init_tables(table_file_path=table_file_path)
         self.syn_dict = self.init_syn_dict(
             syn_dict_file_path=syn_dict_file_path)
 
+    def __del__(self):
+        if self.is_use_sqlite:
+            self.connection_obj.close()
+
     def init_tables(self, table_file_path):
         tables = {}
         lines = []
-        with open(table_file_path, 'r') as fo:
-            for line in fo:
-                lines.append(line)
+        if type(table_file_path) == str:
+            with open(table_file_path, 'r') as fo:
+                for line in fo:
+                    lines.append(line)
+        elif type(table_file_path) == list:
+            for path in table_file_path:
+                with open(path, 'r') as fo:
+                    for line in fo:
+                        lines.append(line)
+        else:
+            raise ValueError()
 
         for line in tqdm.tqdm(lines, desc='Load Tables'):
             table = json.loads(line.strip())
@@ -34,6 +56,9 @@ class Database:
             headers_tokens.append(empty_column)
             table['tablelen'] = table_header_length
             table['header_tok'] = headers_tokens
+            table['headerid2name'] = {}
+            for hid, hname in zip(table['header_id'], table['header_name']):
+                table['headerid2name'][hid] = hname
 
             table['header_types'].append('null')
             table['header_units'] = [
@@ -51,6 +76,26 @@ class Database:
                     trie_set[ii].insert(word, word)
 
             table['value_trie'] = trie_set
+
+            # create sqlite
+            if self.is_use_sqlite:
+                cursor_obj = self.connection_obj.cursor()
+                cursor_obj.execute('DROP TABLE IF EXISTS %s' %
+                                   (table['table_id']))
+                header_string = ', '.join([
+                    '%s %s' %
+                    (name, self.type_dict[htype]) for name, htype in zip(
+                        table['header_id'], table['header_types'])
+                ])
+                create_table_string = 'CREATE TABLE %s (%s);' % (
+                    table['table_id'], header_string)
+                cursor_obj.execute(create_table_string)
+                for row in table['rows']:
+                    value_string = ', '.join(['"%s"' % (val) for val in row])
+                    insert_row_string = 'INSERT INTO %s VALUES(%s)' % (
+                        table['table_id'], value_string)
+                    cursor_obj.execute(insert_row_string)
+
             tables[table['table_id']] = table
 
         return tables
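
The INSERTs above splice cell values into the SQL string, which assumes values contain no double quotes; a parameterized variant (a sketch of an alternative, not what this commit does) would sidestep quoting:

    import sqlite3

    conn = sqlite3.connect(':memory:')
    cur = conn.cursor()
    cur.execute('CREATE TABLE demo_table (col1 TEXT, col2 INT)')
    rows = [('a"b', 1), ('c', 2)]  # values containing quotes are handled safely
    cur.executemany('INSERT INTO demo_table VALUES (?, ?)', rows)
    print(cur.execute('SELECT * FROM demo_table').fetchall())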


+23 -10  modelscope/preprocessors/star3/fields/schema_link.py

@@ -287,7 +287,13 @@ class SchemaLinker:
 
         return match_len / (len(nlu_t) + 0.1)
 
-    def get_entity_linking(self, tokenizer, nlu, nlu_t, tables, col_syn_dict):
+    def get_entity_linking(self,
+                           tokenizer,
+                           nlu,
+                           nlu_t,
+                           tables,
+                           col_syn_dict,
+                           history_sql=None):
         """
         get linking between question and schema column
         """
@@ -305,8 +311,7 @@
             typeinfos = []
             for ii, column in enumerate(table['header_name']):
                 column = column.lower()
-                column_new = re.sub('\(.*?\)', '', column)
-                column_new = re.sub('（.*?）', '', column_new)
+                column_new = column
                 cphrase, cscore = self.get_match_phrase(
                     nlu.lower(), column_new)
                 if cscore > 0.3 and cphrase.strip() != '':
@@ -330,7 +335,6 @@
                 for cell in ans.keys():
                     vphrase = cell
                     vscore = 1.0
-                    # print("trie_set find:", cell, ans[cell])
                     phrase_tok = tokenizer.tokenize(vphrase)
                     if len(phrase_tok) == 0 or len(vphrase) < 2:
                         continue
@@ -408,16 +412,25 @@
             match_score = self.get_table_match_score(nlu_t, schema_link)
 
             search_result = {
-                'table_id': table['table_id'],
-                'question_knowledge': final_question,
-                'header_knowledge': final_header,
-                'schema_link': schema_link,
-                'match_score': match_score
+                'table_id':
+                table['table_id'],
+                'question_knowledge':
+                final_question,
+                'header_knowledge':
+                final_header,
+                'schema_link':
+                schema_link,
+                'match_score':
+                match_score,
+                'table_score':
+                int(table['table_id'] == history_sql['from'][0])
+                if history_sql is not None else 0
             }
             search_result_list.append(search_result)
 
         search_result_list = sorted(
-            search_result_list, key=lambda x: x['match_score'],
+            search_result_list,
+            key=lambda x: (x['match_score'], x['table_score']),
             reverse=True)[0:4]
 
         return search_result_list
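
Because the new sort key is a tuple, table_score only breaks ties between tables whose match_score is equal, nudging the table from the history sql to the front; a minimal illustration with made-up scores:

    results = [
        {'table_id': 'a', 'match_score': 0.8, 'table_score': 0},
        {'table_id': 'b', 'match_score': 0.8, 'table_score': 1},
        {'table_id': 'c', 'match_score': 0.9, 'table_score': 0},
    ]
    ranked = sorted(
        results,
        key=lambda x: (x['match_score'], x['table_score']),
        reverse=True)[0:4]
    print([r['table_id'] for r in ranked])  # ['c', 'b', 'a']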

+3 -2  modelscope/preprocessors/star3/table_question_answering_preprocessor.py

@@ -95,7 +95,7 @@ class TableQuestionAnsweringPreprocessor(Preprocessor):
 
         # tokenize question
         question = data['question']
-        history_sql = data['history_sql']
+        history_sql = data.get('history_sql', None)
         nlu = question.lower()
         nlu_t = self.tokenizer.tokenize(nlu)
 
@@ -105,7 +105,8 @@ class TableQuestionAnsweringPreprocessor(Preprocessor):
             nlu=nlu,
             nlu_t=nlu_t,
             tables=self.db.tables,
-            col_syn_dict=self.db.syn_dict)
+            col_syn_dict=self.db.syn_dict,
+            history_sql=history_sql)
 
         # collect data
         datas = self.construct_data(


+1 -16  modelscope/utils/nlp/nlp_utils.py

@@ -2,8 +2,7 @@ from typing import List
 
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.nlp import (ConversationalTextToSqlPipeline,
-                                      DialogStateTrackingPipeline,
-                                      TableQuestionAnsweringPipeline)
+                                      DialogStateTrackingPipeline)
 
 
 def text2sql_tracking_and_print_results(
@@ -42,17 +41,3 @@ def tracking_and_print_dialog_states(
         print(json.dumps(result))
 
         history_states.extend([result[OutputKeys.OUTPUT], {}])
-
-
-def tableqa_tracking_and_print_results(
-        test_case, pipelines: List[TableQuestionAnsweringPipeline]):
-    for pipeline in pipelines:
-        historical_queries = None
-        for question in test_case['utterance']:
-            output_dict = pipeline({
-                'question': question,
-                'history_sql': historical_queries
-            })
-            print('output_dict', output_dict['output'].string,
-                  output_dict['output'].query)
-            historical_queries = output_dict['history']

+77 -17  tests/pipelines/test_table_question_answering.py

@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
 import unittest
+from typing import List
 
 from transformers import BertTokenizer
 
@@ -11,10 +12,60 @@ from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline
 from modelscope.preprocessors import TableQuestionAnsweringPreprocessor
 from modelscope.preprocessors.star3.fields.database import Database
 from modelscope.utils.constant import ModelFile, Tasks
-from modelscope.utils.nlp.nlp_utils import tableqa_tracking_and_print_results
 from modelscope.utils.test_utils import test_level
 
 
+def tableqa_tracking_and_print_results_with_history(
+        pipelines: List[TableQuestionAnsweringPipeline]):
+    test_case = {
+        'utterance': [
+            '有哪些风险类型?',
+            '风险类型有多少种?',
+            '珠江流域的小(2)型水库的库容总量是多少?',
+            '那平均值是多少?',
+            '那水库的名称呢?',
+            '换成中型的呢?',
+            '枣庄营业厅的电话',
+            '那地址呢?',
+            '枣庄营业厅的电话和地址',
+        ]
+    }
+    for p in pipelines:
+        historical_queries = None
+        for question in test_case['utterance']:
+            output_dict = p({
+                'question': question,
+                'history_sql': historical_queries
+            })
+            print('question', question)
+            print('sql text:', output_dict['output'].string)
+            print('sql query:', output_dict['output'].query)
+            print('query result:', output_dict['query_result'])
+            print()
+            historical_queries = output_dict['history']
+
+
+def tableqa_tracking_and_print_results_without_history(
+        pipelines: List[TableQuestionAnsweringPipeline]):
+    test_case = {
+        'utterance': [
+            '有哪些风险类型?',
+            '风险类型有多少种?',
+            '珠江流域的小(2)型水库的库容总量是多少?',
+            '枣庄营业厅的电话',
+            '枣庄营业厅的电话和地址',
+        ]
+    }
+    for p in pipelines:
+        for question in test_case['utterance']:
+            output_dict = p({'question': question})
+            print('question', question)
+            print('sql text:', output_dict['output'].string)
+            print('sql query:', output_dict['output'].query)
+            print('query result:', output_dict['query_result'])
+            print()
+
+
 class TableQuestionAnswering(unittest.TestCase):
 
     def setUp(self) -> None:
@@ -22,20 +73,18 @@ class TableQuestionAnswering(unittest.TestCase):
         self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn'
 
     model_id = 'damo/nlp_convai_text2sql_pretrain_cn'
-    test_case = {
-        'utterance':
-        ['长江流域的小(2)型水库的库容总量是多少?', '那平均值是多少?', '那水库的名称呢?', '换成中型的呢?']
-    }
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_by_direct_model_download(self):
         cache_path = snapshot_download(self.model_id)
         preprocessor = TableQuestionAnsweringPreprocessor(model_dir=cache_path)
         pipelines = [
-            TableQuestionAnsweringPipeline(
-                model=cache_path, preprocessor=preprocessor)
+            pipeline(
+                Tasks.table_question_answering,
+                model=cache_path,
+                preprocessor=preprocessor)
         ]
-        tableqa_tracking_and_print_results(self.test_case, pipelines)
+        tableqa_tracking_and_print_results_with_history(pipelines)
 
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_model_from_modelhub(self):
@@ -43,15 +92,17 @@ class TableQuestionAnswering(unittest.TestCase):
         preprocessor = TableQuestionAnsweringPreprocessor(
             model_dir=model.model_dir)
         pipelines = [
-            TableQuestionAnsweringPipeline(
-                model=model, preprocessor=preprocessor)
+            pipeline(
+                Tasks.table_question_answering,
+                model=model,
+                preprocessor=preprocessor)
         ]
-        tableqa_tracking_and_print_results(self.test_case, pipelines)
+        tableqa_tracking_and_print_results_with_history(pipelines)
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_model_from_task(self):
         pipelines = [pipeline(Tasks.table_question_answering, self.model_id)]
-        tableqa_tracking_and_print_results(self.test_case, pipelines)
+        tableqa_tracking_and_print_results_with_history(pipelines)
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_model_from_modelhub_with_other_classes(self):
@@ -60,15 +111,24 @@ class TableQuestionAnswering(unittest.TestCase):
             os.path.join(model.model_dir, ModelFile.VOCAB_FILE))
         db = Database(
             tokenizer=self.tokenizer,
-            table_file_path=os.path.join(model.model_dir, 'table.json'),
-            syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'))
+            table_file_path=[
+                os.path.join(model.model_dir, 'databases', fname)
+                for fname in os.listdir(
+                    os.path.join(model.model_dir, 'databases'))
+            ],
+            syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'),
+            is_use_sqlite=True)
         preprocessor = TableQuestionAnsweringPreprocessor(
             model_dir=model.model_dir, db=db)
         pipelines = [
-            TableQuestionAnsweringPipeline(
-                model=model, preprocessor=preprocessor, db=db)
+            pipeline(
+                Tasks.table_question_answering,
+                model=model,
+                preprocessor=preprocessor,
+                db=db)
         ]
-        tableqa_tracking_and_print_results(self.test_case, pipelines)
+        tableqa_tracking_and_print_results_without_history(pipelines)
+        tableqa_tracking_and_print_results_with_history(pipelines)
 
 
 if __name__ == '__main__':

