diff --git a/modelscope/pipelines/nlp/table_question_answering_pipeline.py b/modelscope/pipelines/nlp/table_question_answering_pipeline.py index 52ba33e0..fc0d07b1 100644 --- a/modelscope/pipelines/nlp/table_question_answering_pipeline.py +++ b/modelscope/pipelines/nlp/table_question_answering_pipeline.py @@ -17,6 +17,9 @@ from modelscope.preprocessors.space_T_cn.fields.database import Database from modelscope.preprocessors.space_T_cn.fields.struct import (Constant, SQLQuery) from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() __all__ = ['TableQuestionAnsweringPipeline'] @@ -309,7 +312,8 @@ class TableQuestionAnsweringPipeline(Pipeline): 'header_name': header_names, 'rows': rows } - except Exception: + except Exception as e: + logger.error(e) tabledata = {'header_id': [], 'header_name': [], 'rows': []} else: tabledata = {'header_id': [], 'header_name': [], 'rows': []} diff --git a/modelscope/preprocessors/space_T_cn/fields/database.py b/modelscope/preprocessors/space_T_cn/fields/database.py index 481bd1db..7ae38ee2 100644 --- a/modelscope/preprocessors/space_T_cn/fields/database.py +++ b/modelscope/preprocessors/space_T_cn/fields/database.py @@ -17,7 +17,8 @@ class Database: self.tokenizer = tokenizer self.is_use_sqlite = is_use_sqlite if self.is_use_sqlite: - self.connection_obj = sqlite3.connect(':memory:') + self.connection_obj = sqlite3.connect( + ':memory:', check_same_thread=False) self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'} self.tables = self.init_tables(table_file_path=table_file_path) self.syn_dict = self.init_syn_dict( diff --git a/tests/pipelines/test_table_question_answering.py b/tests/pipelines/test_table_question_answering.py index 828ef5ac..44f1531b 100644 --- a/tests/pipelines/test_table_question_answering.py +++ b/tests/pipelines/test_table_question_answering.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os import unittest +from threading import Thread from typing import List import json @@ -108,8 +109,6 @@ class TableQuestionAnswering(unittest.TestCase): self.task = Tasks.table_question_answering self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn' - model_id = 'damo/nlp_convai_text2sql_pretrain_cn' - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_by_direct_model_download(self): cache_path = snapshot_download(self.model_id) @@ -122,6 +121,27 @@ class TableQuestionAnswering(unittest.TestCase): ] tableqa_tracking_and_print_results_with_history(pipelines) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download_with_multithreads(self): + cache_path = snapshot_download(self.model_id) + pl = pipeline(Tasks.table_question_answering, model=cache_path) + + def print_func(pl, i): + result = pl({ + 'question': '长江流域的小(2)型水库的库容总量是多少?', + 'table_id': 'reservoir', + 'history_sql': None + }) + print(i, json.dumps(result)) + + procs = [] + for i in range(5): + proc = Thread(target=print_func, args=(pl, i)) + procs.append(proc) + proc.start() + for proc in procs: + proc.join() + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_model_from_modelhub(self): model = Model.from_pretrained(self.model_id)