Browse Source

[to #42322933] add unit test (UT) for multi-threaded pipeline usage

1. 修复multi thread引起的问题
2. 增加multi thread的unittest
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10502008
master
caorongyu.cry yingda.chen 3 years ago
parent
commit
6178f46910
3 changed files with 29 additions and 4 deletions
  1. +5
    -1
      modelscope/pipelines/nlp/table_question_answering_pipeline.py
  2. +2
    -1
      modelscope/preprocessors/space_T_cn/fields/database.py
  3. +22
    -2
      tests/pipelines/test_table_question_answering.py

+ 5
- 1
modelscope/pipelines/nlp/table_question_answering_pipeline.py View File

@@ -17,6 +17,9 @@ from modelscope.preprocessors.space_T_cn.fields.database import Database
from modelscope.preprocessors.space_T_cn.fields.struct import (Constant, from modelscope.preprocessors.space_T_cn.fields.struct import (Constant,
SQLQuery) SQLQuery)
from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


__all__ = ['TableQuestionAnsweringPipeline'] __all__ = ['TableQuestionAnsweringPipeline']


@@ -309,7 +312,8 @@ class TableQuestionAnsweringPipeline(Pipeline):
'header_name': header_names, 'header_name': header_names,
'rows': rows 'rows': rows
} }
except Exception:
except Exception as e:
logger.error(e)
tabledata = {'header_id': [], 'header_name': [], 'rows': []} tabledata = {'header_id': [], 'header_name': [], 'rows': []}
else: else:
tabledata = {'header_id': [], 'header_name': [], 'rows': []} tabledata = {'header_id': [], 'header_name': [], 'rows': []}


+ 2
- 1
modelscope/preprocessors/space_T_cn/fields/database.py View File

@@ -17,7 +17,8 @@ class Database:
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.is_use_sqlite = is_use_sqlite self.is_use_sqlite = is_use_sqlite
if self.is_use_sqlite: if self.is_use_sqlite:
self.connection_obj = sqlite3.connect(':memory:')
self.connection_obj = sqlite3.connect(
':memory:', check_same_thread=False)
self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'} self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'}
self.tables = self.init_tables(table_file_path=table_file_path) self.tables = self.init_tables(table_file_path=table_file_path)
self.syn_dict = self.init_syn_dict( self.syn_dict = self.init_syn_dict(


+ 22
- 2
tests/pipelines/test_table_question_answering.py View File

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import os import os
import unittest import unittest
from threading import Thread
from typing import List from typing import List


import json import json
@@ -108,8 +109,6 @@ class TableQuestionAnswering(unittest.TestCase):
self.task = Tasks.table_question_answering self.task = Tasks.table_question_answering
self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn' self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn'


model_id = 'damo/nlp_convai_text2sql_pretrain_cn'

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self): def test_run_by_direct_model_download(self):
cache_path = snapshot_download(self.model_id) cache_path = snapshot_download(self.model_id)
@@ -122,6 +121,27 @@ class TableQuestionAnswering(unittest.TestCase):
] ]
tableqa_tracking_and_print_results_with_history(pipelines) tableqa_tracking_and_print_results_with_history(pipelines)


@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download_with_multithreads(self):
    """Call one shared pipeline instance from several threads at once.

    A single pipeline is built from the downloaded model snapshot and
    then invoked concurrently by five threads with the same question.
    Any exception raised inside a worker is collected and makes the
    test fail after all threads have been joined.
    """
    cache_path = snapshot_download(self.model_id)
    pl = pipeline(Tasks.table_question_answering, model=cache_path)

    # An exception raised in a Thread target is not propagated to the
    # thread that calls join(), so without this list the test would
    # pass even if every worker crashed.
    errors = []

    def print_func(pl, i):
        try:
            result = pl({
                'question': '长江流域的小(2)型水库的库容总量是多少?',
                'table_id': 'reservoir',
                'history_sql': None
            })
            print(i, json.dumps(result))
        except Exception as e:
            # Record the failure with the worker index for the final
            # assertion; list.append is safe to call from threads.
            errors.append((i, repr(e)))

    procs = []
    for i in range(5):
        proc = Thread(target=print_func, args=(pl, i))
        procs.append(proc)
        proc.start()
    for proc in procs:
        proc.join()

    # Surface worker failures to the test runner.
    self.assertEqual(errors, [])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_from_modelhub(self): def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id) model = Model.from_pretrained(self.model_id)


Loading…
Cancel
Save