From 3b8fb92c136f9c139e7dfad2de0d164b413290a4 Mon Sep 17 00:00:00 2001 From: "caorongyu.cry" Date: Wed, 26 Oct 2022 16:04:14 +0800 Subject: [PATCH] [to #42322933] debug header ids and header names MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复header_ids和header_names命名反了的问题 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10516557 --- .../nlp/table_question_answering_pipeline.py | 66 +++++++++++++++---- .../nlp/space_T_cn/fields/struct.py | 22 +++++++ .../test_table_question_answering.py | 8 ++- 3 files changed, 81 insertions(+), 15 deletions(-) diff --git a/modelscope/pipelines/nlp/table_question_answering_pipeline.py b/modelscope/pipelines/nlp/table_question_answering_pipeline.py index 826e35a9..b75a8153 100644 --- a/modelscope/pipelines/nlp/table_question_answering_pipeline.py +++ b/modelscope/pipelines/nlp/table_question_answering_pipeline.py @@ -69,6 +69,7 @@ class TableQuestionAnsweringPipeline(Pipeline): self.max_where_num = constant.max_where_num self.col_type_dict = constant.col_type_dict self.schema_link_dict = constant.schema_link_dict + self.limit_dict = constant.limit_dict super().__init__(model=model, preprocessor=preprocessor, **kwargs) @@ -244,7 +245,13 @@ class TableQuestionAnsweringPipeline(Pipeline): + header_id + ')') str_cond_list, sql_cond_list = [], [] + where_conds, orderby_conds = [], [] for cond in sql['conds']: + if cond[1] in [4, 5]: + orderby_conds.append(cond) + else: + where_conds.append(cond) + for cond in where_conds: header_name = header_names[cond[0]] if header_name == '空列': continue @@ -255,14 +262,49 @@ class TableQuestionAnsweringPipeline(Pipeline): + '" )') sql_cond_list.append('( ' + header_id + ' ' + op + ' "' + value + '" )') - - cond = ' ' + self.cond_conn_ops[sql['cond_conn_op']] + ' ' - - if len(str_cond_list) != 0: - final_str = 'SELECT %s FROM %s WHERE %s' % (', '.join( - str_sel_list), table['table_name'], cond.join(str_cond_list)) - final_sql = 'SELECT %s FROM `%s` WHERE %s' % (', '.join( - sql_sel_list), table['table_id'], cond.join(sql_cond_list)) + cond_str = ' ' + self.cond_conn_ops[sql['cond_conn_op']] + ' ' + str_where_conds = cond_str.join(str_cond_list) + sql_where_conds = cond_str.join(sql_cond_list) + if len(orderby_conds) != 0: + str_orderby_column = ', '.join( + [header_names[cond[0]] for cond in orderby_conds]) + sql_orderby_column = ', '.join([ + '`%s`.`%s`' % (table['table_id'], header_ids[cond[0]]) + for cond in orderby_conds + ]) + str_orderby_op = self.cond_ops[orderby_conds[0][1]] + str_orderby = '%s %s' % (str_orderby_column, str_orderby_op) + sql_orderby = '%s %s' % (sql_orderby_column, str_orderby_op) + limit_key = orderby_conds[0][2] + is_in, limit_num = False, -1 + for key in self.limit_dict: + if key in limit_key: + is_in = True + limit_num = self.limit_dict[key] + break + if is_in: + str_orderby += ' LIMIT %d' % (limit_num) + sql_orderby += ' LIMIT %d' % (limit_num) + else: + str_orderby = '' + + if len(str_cond_list) != 0 and len(str_orderby) != 0: + final_str = 'SELECT %s FROM %s WHERE %s ORDER BY %s' % ( + ', '.join(str_sel_list), table['table_name'], str_where_conds, + str_orderby) + final_sql = 'SELECT %s FROM `%s` WHERE %s ORDER BY %s' % ( + ', '.join(sql_sel_list), table['table_id'], sql_where_conds, + sql_orderby) + elif len(str_cond_list) != 0: + final_str = 'SELECT %s FROM %s WHERE %s' % ( + ', '.join(str_sel_list), table['table_name'], str_where_conds) + final_sql = 'SELECT %s FROM `%s` WHERE %s' % ( + ', '.join(sql_sel_list), table['table_id'], sql_where_conds) + elif len(str_orderby) != 0: + final_str = 'SELECT %s FROM %s ORDER BY %s' % ( + ', '.join(str_sel_list), table['table_name'], str_orderby) + final_sql = 'SELECT %s FROM `%s` ORDER BY %s' % ( + ', '.join(sql_sel_list), table['table_id'], sql_orderby) else: final_str = 'SELECT %s FROM %s' % (', '.join(str_sel_list), table['table_name']) @@ -300,10 +342,10 @@ class TableQuestionAnsweringPipeline(Pipeline): cursor = self.db.connection_obj.cursor().execute(sql.query) header_ids, header_names = [], [] for description in cursor.description: - header_ids.append(self.db.tables[result['table_id']] - ['headerid2name'].get( - description[0], description[0])) - header_names.append(description[0]) + header_names.append(self.db.tables[result['table_id']] + ['headerid2name'].get( + description[0], description[0])) + header_ids.append(description[0]) rows = [] for res in cursor.fetchall(): rows.append(list(res)) diff --git a/modelscope/preprocessors/nlp/space_T_cn/fields/struct.py b/modelscope/preprocessors/nlp/space_T_cn/fields/struct.py index 3c2e664b..917e1aaa 100644 --- a/modelscope/preprocessors/nlp/space_T_cn/fields/struct.py +++ b/modelscope/preprocessors/nlp/space_T_cn/fields/struct.py @@ -179,3 +179,25 @@ class Constant: self.max_select_num = 4 self.max_where_num = 6 + + self.limit_dict = { + '最': 1, + '1': 1, + '一': 1, + '2': 2, + '二': 2, + '3': 3, + '三': 3, + '4': 4, + '四': 4, + '5': 5, + '五': 5, + '6': 6, + '六': 6, + '7': 7, + '七': 7, + '8': 8, + '八': 8, + '9': 9, + '九': 9 + } diff --git a/tests/pipelines/test_table_question_answering.py b/tests/pipelines/test_table_question_answering.py index eece7f57..825d8f23 100644 --- a/tests/pipelines/test_table_question_answering.py +++ b/tests/pipelines/test_table_question_answering.py @@ -128,11 +128,13 @@ class TableQuestionAnswering(unittest.TestCase): def print_func(pl, i): result = pl({ - 'question': '长江流域的小(2)型水库的库容总量是多少?', - 'table_id': 'reservoir', + 'question': '上个月收益从低到高排前七的基金的名称和风险等级是什么', + 'table_id': 'fund', 'history_sql': None }) - print(i, json.dumps(result)) + print(i, result[OutputKeys.OUTPUT][OutputKeys.SQL_QUERY], + result[OutputKeys.OUTPUT][OutputKeys.QUERT_RESULT], + json.dumps(result)) procs = [] for i in range(5):