Browse Source

[to #42322933] fix swapped header ids and header names

修复header_ids和header_names命名反了的问题
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10516557
master
caorongyu.cry yingda.chen 3 years ago
parent
commit
3b8fb92c13
3 changed files with 81 additions and 15 deletions
  1. +54
    -12
      modelscope/pipelines/nlp/table_question_answering_pipeline.py
  2. +22
    -0
      modelscope/preprocessors/nlp/space_T_cn/fields/struct.py
  3. +5
    -3
      tests/pipelines/test_table_question_answering.py

+ 54
- 12
modelscope/pipelines/nlp/table_question_answering_pipeline.py View File

@@ -69,6 +69,7 @@ class TableQuestionAnsweringPipeline(Pipeline):
self.max_where_num = constant.max_where_num
self.col_type_dict = constant.col_type_dict
self.schema_link_dict = constant.schema_link_dict
self.limit_dict = constant.limit_dict

super().__init__(model=model, preprocessor=preprocessor, **kwargs)

@@ -244,7 +245,13 @@ class TableQuestionAnsweringPipeline(Pipeline):
+ header_id + ')')

str_cond_list, sql_cond_list = [], []
where_conds, orderby_conds = [], []
for cond in sql['conds']:
if cond[1] in [4, 5]:
orderby_conds.append(cond)
else:
where_conds.append(cond)
for cond in where_conds:
header_name = header_names[cond[0]]
if header_name == '空列':
continue
@@ -255,14 +262,49 @@ class TableQuestionAnsweringPipeline(Pipeline):
+ '" )')
sql_cond_list.append('( ' + header_id + ' ' + op + ' "' + value
+ '" )')

cond = ' ' + self.cond_conn_ops[sql['cond_conn_op']] + ' '

if len(str_cond_list) != 0:
final_str = 'SELECT %s FROM %s WHERE %s' % (', '.join(
str_sel_list), table['table_name'], cond.join(str_cond_list))
final_sql = 'SELECT %s FROM `%s` WHERE %s' % (', '.join(
sql_sel_list), table['table_id'], cond.join(sql_cond_list))
cond_str = ' ' + self.cond_conn_ops[sql['cond_conn_op']] + ' '
str_where_conds = cond_str.join(str_cond_list)
sql_where_conds = cond_str.join(sql_cond_list)
if len(orderby_conds) != 0:
str_orderby_column = ', '.join(
[header_names[cond[0]] for cond in orderby_conds])
sql_orderby_column = ', '.join([
'`%s`.`%s`' % (table['table_id'], header_ids[cond[0]])
for cond in orderby_conds
])
str_orderby_op = self.cond_ops[orderby_conds[0][1]]
str_orderby = '%s %s' % (str_orderby_column, str_orderby_op)
sql_orderby = '%s %s' % (sql_orderby_column, str_orderby_op)
limit_key = orderby_conds[0][2]
is_in, limit_num = False, -1
for key in self.limit_dict:
if key in limit_key:
is_in = True
limit_num = self.limit_dict[key]
break
if is_in:
str_orderby += ' LIMIT %d' % (limit_num)
sql_orderby += ' LIMIT %d' % (limit_num)
else:
str_orderby = ''

if len(str_cond_list) != 0 and len(str_orderby) != 0:
final_str = 'SELECT %s FROM %s WHERE %s ORDER BY %s' % (
', '.join(str_sel_list), table['table_name'], str_where_conds,
str_orderby)
final_sql = 'SELECT %s FROM `%s` WHERE %s ORDER BY %s' % (
', '.join(sql_sel_list), table['table_id'], sql_where_conds,
sql_orderby)
elif len(str_cond_list) != 0:
final_str = 'SELECT %s FROM %s WHERE %s' % (
', '.join(str_sel_list), table['table_name'], str_where_conds)
final_sql = 'SELECT %s FROM `%s` WHERE %s' % (
', '.join(sql_sel_list), table['table_id'], sql_where_conds)
elif len(str_orderby) != 0:
final_str = 'SELECT %s FROM %s ORDER BY %s' % (
', '.join(str_sel_list), table['table_name'], str_orderby)
final_sql = 'SELECT %s FROM `%s` ORDER BY %s' % (
', '.join(sql_sel_list), table['table_id'], sql_orderby)
else:
final_str = 'SELECT %s FROM %s' % (', '.join(str_sel_list),
table['table_name'])
@@ -300,10 +342,10 @@ class TableQuestionAnsweringPipeline(Pipeline):
cursor = self.db.connection_obj.cursor().execute(sql.query)
header_ids, header_names = [], []
for description in cursor.description:
header_ids.append(self.db.tables[result['table_id']]
['headerid2name'].get(
description[0], description[0]))
header_names.append(description[0])
header_names.append(self.db.tables[result['table_id']]
['headerid2name'].get(
description[0], description[0]))
header_ids.append(description[0])
rows = []
for res in cursor.fetchall():
rows.append(list(res))


+ 22
- 0
modelscope/preprocessors/nlp/space_T_cn/fields/struct.py View File

@@ -179,3 +179,25 @@ class Constant:
self.max_select_num = 4

self.max_where_num = 6

self.limit_dict = {
'最': 1,
'1': 1,
'一': 1,
'2': 2,
'二': 2,
'3': 3,
'三': 3,
'4': 4,
'四': 4,
'5': 5,
'五': 5,
'6': 6,
'六': 6,
'7': 7,
'七': 7,
'8': 8,
'八': 8,
'9': 9,
'九': 9
}

+ 5
- 3
tests/pipelines/test_table_question_answering.py View File

@@ -128,11 +128,13 @@ class TableQuestionAnswering(unittest.TestCase):

def print_func(pl, i):
result = pl({
'question': '长江流域的小(2)型水库的库容总量是多少?',
'table_id': 'reservoir',
'question': '上个月收益从低到高排前七的基金的名称和风险等级是什么',
'table_id': 'fund',
'history_sql': None
})
print(i, json.dumps(result))
print(i, result[OutputKeys.OUTPUT][OutputKeys.SQL_QUERY],
result[OutputKeys.OUTPUT][OutputKeys.QUERT_RESULT],
json.dumps(result))

procs = []
for i in range(5):


Loading…
Cancel
Save