|
|
|
@@ -69,6 +69,7 @@ class TableQuestionAnsweringPipeline(Pipeline): |
|
|
|
self.max_where_num = constant.max_where_num |
|
|
|
self.col_type_dict = constant.col_type_dict |
|
|
|
self.schema_link_dict = constant.schema_link_dict |
|
|
|
self.limit_dict = constant.limit_dict |
|
|
|
|
|
|
|
super().__init__(model=model, preprocessor=preprocessor, **kwargs) |
|
|
|
|
|
|
|
@@ -244,7 +245,13 @@ class TableQuestionAnsweringPipeline(Pipeline): |
|
|
|
+ header_id + ')') |
|
|
|
|
|
|
|
str_cond_list, sql_cond_list = [], [] |
|
|
|
where_conds, orderby_conds = [], [] |
|
|
|
for cond in sql['conds']: |
|
|
|
if cond[1] in [4, 5]: |
|
|
|
orderby_conds.append(cond) |
|
|
|
else: |
|
|
|
where_conds.append(cond) |
|
|
|
for cond in where_conds: |
|
|
|
header_name = header_names[cond[0]] |
|
|
|
if header_name == '空列': |
|
|
|
continue |
|
|
|
@@ -255,14 +262,49 @@ class TableQuestionAnsweringPipeline(Pipeline): |
|
|
|
+ '" )') |
|
|
|
sql_cond_list.append('( ' + header_id + ' ' + op + ' "' + value |
|
|
|
+ '" )') |
|
|
|
|
|
|
|
cond = ' ' + self.cond_conn_ops[sql['cond_conn_op']] + ' ' |
|
|
|
|
|
|
|
if len(str_cond_list) != 0: |
|
|
|
final_str = 'SELECT %s FROM %s WHERE %s' % (', '.join( |
|
|
|
str_sel_list), table['table_name'], cond.join(str_cond_list)) |
|
|
|
final_sql = 'SELECT %s FROM `%s` WHERE %s' % (', '.join( |
|
|
|
sql_sel_list), table['table_id'], cond.join(sql_cond_list)) |
|
|
|
cond_str = ' ' + self.cond_conn_ops[sql['cond_conn_op']] + ' ' |
|
|
|
str_where_conds = cond_str.join(str_cond_list) |
|
|
|
sql_where_conds = cond_str.join(sql_cond_list) |
|
|
|
if len(orderby_conds) != 0: |
|
|
|
str_orderby_column = ', '.join( |
|
|
|
[header_names[cond[0]] for cond in orderby_conds]) |
|
|
|
sql_orderby_column = ', '.join([ |
|
|
|
'`%s`.`%s`' % (table['table_id'], header_ids[cond[0]]) |
|
|
|
for cond in orderby_conds |
|
|
|
]) |
|
|
|
str_orderby_op = self.cond_ops[orderby_conds[0][1]] |
|
|
|
str_orderby = '%s %s' % (str_orderby_column, str_orderby_op) |
|
|
|
sql_orderby = '%s %s' % (sql_orderby_column, str_orderby_op) |
|
|
|
limit_key = orderby_conds[0][2] |
|
|
|
is_in, limit_num = False, -1 |
|
|
|
for key in self.limit_dict: |
|
|
|
if key in limit_key: |
|
|
|
is_in = True |
|
|
|
limit_num = self.limit_dict[key] |
|
|
|
break |
|
|
|
if is_in: |
|
|
|
str_orderby += ' LIMIT %d' % (limit_num) |
|
|
|
sql_orderby += ' LIMIT %d' % (limit_num) |
|
|
|
else: |
|
|
|
str_orderby = '' |
|
|
|
|
|
|
|
if len(str_cond_list) != 0 and len(str_orderby) != 0: |
|
|
|
final_str = 'SELECT %s FROM %s WHERE %s ORDER BY %s' % ( |
|
|
|
', '.join(str_sel_list), table['table_name'], str_where_conds, |
|
|
|
str_orderby) |
|
|
|
final_sql = 'SELECT %s FROM `%s` WHERE %s ORDER BY %s' % ( |
|
|
|
', '.join(sql_sel_list), table['table_id'], sql_where_conds, |
|
|
|
sql_orderby) |
|
|
|
elif len(str_cond_list) != 0: |
|
|
|
final_str = 'SELECT %s FROM %s WHERE %s' % ( |
|
|
|
', '.join(str_sel_list), table['table_name'], str_where_conds) |
|
|
|
final_sql = 'SELECT %s FROM `%s` WHERE %s' % ( |
|
|
|
', '.join(sql_sel_list), table['table_id'], sql_where_conds) |
|
|
|
elif len(str_orderby) != 0: |
|
|
|
final_str = 'SELECT %s FROM %s ORDER BY %s' % ( |
|
|
|
', '.join(str_sel_list), table['table_name'], str_orderby) |
|
|
|
final_sql = 'SELECT %s FROM `%s` ORDER BY %s' % ( |
|
|
|
', '.join(sql_sel_list), table['table_id'], sql_orderby) |
|
|
|
else: |
|
|
|
final_str = 'SELECT %s FROM %s' % (', '.join(str_sel_list), |
|
|
|
table['table_name']) |
|
|
|
@@ -300,10 +342,10 @@ class TableQuestionAnsweringPipeline(Pipeline): |
|
|
|
cursor = self.db.connection_obj.cursor().execute(sql.query) |
|
|
|
header_ids, header_names = [], [] |
|
|
|
for description in cursor.description: |
|
|
|
header_ids.append(self.db.tables[result['table_id']] |
|
|
|
['headerid2name'].get( |
|
|
|
description[0], description[0])) |
|
|
|
header_names.append(description[0]) |
|
|
|
header_names.append(self.db.tables[result['table_id']] |
|
|
|
['headerid2name'].get( |
|
|
|
description[0], description[0])) |
|
|
|
header_ids.append(description[0]) |
|
|
|
rows = [] |
|
|
|
for res in cursor.fetchall(): |
|
|
|
rows.append(list(res)) |
|
|
|
|