| @@ -12,7 +12,6 @@ from http.cookiejar import CookieJar | |||||
| from os.path import expanduser | from os.path import expanduser | ||||
| from typing import List, Optional, Tuple, Union | from typing import List, Optional, Tuple, Union | ||||
| import attrs | |||||
| import requests | import requests | ||||
| from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | ||||
| @@ -22,14 +21,9 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||||
| API_RESPONSE_FIELD_USERNAME, | API_RESPONSE_FIELD_USERNAME, | ||||
| DEFAULT_CREDENTIALS_PATH, Licenses, | DEFAULT_CREDENTIALS_PATH, Licenses, | ||||
| ModelVisibility) | ModelVisibility) | ||||
| from modelscope.hub.deploy import (DeleteServiceParameters, | |||||
| DeployServiceParameters, | |||||
| GetServiceParameters, ListServiceParameters, | |||||
| ServiceParameters, ServiceResourceConfig, | |||||
| Vendor) | |||||
| from modelscope.hub.errors import (InvalidParameter, NotExistError, | from modelscope.hub.errors import (InvalidParameter, NotExistError, | ||||
| NotLoginException, NotSupportError, | |||||
| RequestError, datahub_raise_on_error, | |||||
| NotLoginException, RequestError, | |||||
| datahub_raise_on_error, | |||||
| handle_http_post_error, | handle_http_post_error, | ||||
| handle_http_response, is_ok, raise_on_error) | handle_http_response, is_ok, raise_on_error) | ||||
| from modelscope.hub.git import GitCommandWrapper | from modelscope.hub.git import GitCommandWrapper | ||||
| @@ -312,169 +306,6 @@ class HubApi: | |||||
| r.raise_for_status() | r.raise_for_status() | ||||
| return None | return None | ||||
| def deploy_model(self, model_id: str, revision: str, instance_name: str, | |||||
| resource: ServiceResourceConfig, | |||||
| provider: ServiceParameters): | |||||
| """Deploy model to cloud, current we only support PAI EAS, this is asynchronous | |||||
| call , please check instance status through the console or query the instance status. | |||||
| At the same time, this call may take a long time. | |||||
| Args: | |||||
| model_id (str): The deployed model id | |||||
| revision (str): The model revision | |||||
| instance_name (str): The deployed model instance name. | |||||
| resource (DeployResource): The resource information. | |||||
| provider (CreateParameter): The cloud service provider parameter | |||||
| Raises: | |||||
| NotLoginException: To use this api, you need login first. | |||||
| NotSupportError: Not supported platform. | |||||
| RequestError: The server return error. | |||||
| Returns: | |||||
| InstanceInfo: The instance information. | |||||
| """ | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| if cookies is None: | |||||
| raise NotLoginException( | |||||
| 'Token does not exist, please login first.') | |||||
| if provider.vendor != Vendor.EAS: | |||||
| raise NotSupportError( | |||||
| 'Not support vendor: %s ,only support EAS current.' % | |||||
| (provider.vendor)) | |||||
| create_params = DeployServiceParameters( | |||||
| instance_name=instance_name, | |||||
| model_id=model_id, | |||||
| revision=revision, | |||||
| resource=resource, | |||||
| provider=provider) | |||||
| path = f'{self.endpoint}/api/v1/deployer/endpoint' | |||||
| body = attrs.asdict(create_params) | |||||
| r = requests.post( | |||||
| path, | |||||
| json=body, | |||||
| cookies=cookies, | |||||
| ) | |||||
| handle_http_response(r, logger, cookies, 'create_eas_instance') | |||||
| if r.status_code >= HTTPStatus.OK and r.status_code < HTTPStatus.MULTIPLE_CHOICES: | |||||
| if is_ok(r.json()): | |||||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||||
| return data | |||||
| else: | |||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||||
| else: | |||||
| r.raise_for_status() | |||||
| return None | |||||
| def list_deployed_model_instances(self, | |||||
| provider: ServiceParameters, | |||||
| skip: int = 0, | |||||
| limit: int = 100): | |||||
| """List deployed model instances. | |||||
| Args: | |||||
| provider (ListServiceParameter): The cloud service provider parameter, | |||||
| for eas, need access_key_id and access_key_secret. | |||||
| skip: start of the list, current not support. | |||||
| limit: maximum number of instances return, current not support | |||||
| Raises: | |||||
| NotLoginException: To use this api, you need login first. | |||||
| RequestError: The request is failed from server. | |||||
| Returns: | |||||
| List: List of instance information | |||||
| """ | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| if cookies is None: | |||||
| raise NotLoginException( | |||||
| 'Token does not exist, please login first.') | |||||
| params = ListServiceParameters( | |||||
| provider=provider, skip=skip, limit=limit) | |||||
| path = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint, | |||||
| params.to_query_str()) | |||||
| r = requests.get(path, cookies=cookies) | |||||
| handle_http_response(r, logger, cookies, 'list_deployed_model') | |||||
| if r.status_code == HTTPStatus.OK: | |||||
| if is_ok(r.json()): | |||||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||||
| return data | |||||
| else: | |||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||||
| else: | |||||
| r.raise_for_status() | |||||
| return None | |||||
| def get_deployed_model_instance(self, instance_name: str, | |||||
| provider: ServiceParameters): | |||||
| """Query the specified instance information. | |||||
| Args: | |||||
| instance_name (str): The deployed instance name. | |||||
| provider (GetParameter): The cloud provider information, for eas | |||||
| need region(eg: ch-hangzhou), access_key_id and access_key_secret. | |||||
| Raises: | |||||
| NotLoginException: To use this api, you need login first. | |||||
| RequestError: The request is failed from server. | |||||
| Returns: | |||||
| Dict: The request instance information | |||||
| """ | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| if cookies is None: | |||||
| raise NotLoginException( | |||||
| 'Token does not exist, please login first.') | |||||
| params = GetServiceParameters(provider=provider) | |||||
| path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | |||||
| self.endpoint, instance_name, params.to_query_str()) | |||||
| r = requests.get(path, cookies=cookies) | |||||
| handle_http_response(r, logger, cookies, 'get_deployed_model') | |||||
| if r.status_code == HTTPStatus.OK: | |||||
| if is_ok(r.json()): | |||||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||||
| return data | |||||
| else: | |||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||||
| else: | |||||
| r.raise_for_status() | |||||
| return None | |||||
| def delete_deployed_model_instance(self, instance_name: str, | |||||
| provider: ServiceParameters): | |||||
| """Delete deployed model, this api send delete command and return, it will take | |||||
| some to delete, please check through the cloud console. | |||||
| Args: | |||||
| instance_name (str): The instance name you want to delete. | |||||
| provider (DeleteParameter): The cloud provider information, for eas | |||||
| need region(eg: ch-hangzhou), access_key_id and access_key_secret. | |||||
| Raises: | |||||
| NotLoginException: To call this api, you need login first. | |||||
| RequestError: The request is failed. | |||||
| Returns: | |||||
| Dict: The deleted instance information. | |||||
| """ | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| if cookies is None: | |||||
| raise NotLoginException( | |||||
| 'Token does not exist, please login first.') | |||||
| params = DeleteServiceParameters(provider=provider) | |||||
| path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | |||||
| self.endpoint, instance_name, params.to_query_str()) | |||||
| r = requests.delete(path, cookies=cookies) | |||||
| handle_http_response(r, logger, cookies, 'delete_deployed_model') | |||||
| if r.status_code == HTTPStatus.OK: | |||||
| if is_ok(r.json()): | |||||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||||
| return data | |||||
| else: | |||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||||
| else: | |||||
| r.raise_for_status() | |||||
| return None | |||||
| def _check_cookie(self, | def _check_cookie(self, | ||||
| use_cookies: Union[bool, | use_cookies: Union[bool, | ||||
| CookieJar] = False) -> CookieJar: | CookieJar] = False) -> CookieJar: | ||||
| @@ -1,11 +1,25 @@ | |||||
| import urllib | import urllib | ||||
| from abc import ABC, abstractmethod | |||||
| from typing import Optional, Union | |||||
| from abc import ABC | |||||
| from http import HTTPStatus | |||||
| from typing import Optional | |||||
| import attrs | |||||
| import json | import json | ||||
| from attr import fields | |||||
| import requests | |||||
| from attrs import asdict, define, field, validators | from attrs import asdict, define, field, validators | ||||
| from modelscope.hub.api import ModelScopeConfig | |||||
| from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||||
| API_RESPONSE_FIELD_MESSAGE) | |||||
| from modelscope.hub.errors import (NotLoginException, NotSupportError, | |||||
| RequestError, handle_http_response, is_ok) | |||||
| from modelscope.hub.utils.utils import get_endpoint | |||||
| from modelscope.utils.logger import get_logger | |||||
| # yapf: enable | |||||
| logger = get_logger() | |||||
| class Accelerator(object): | class Accelerator(object): | ||||
| CPU = 'cpu' | CPU = 'cpu' | ||||
| @@ -76,12 +90,12 @@ class ServiceResourceConfig(object): | |||||
| @define | @define | ||||
| class ServiceParameters(ABC): | |||||
| class ServiceProviderParameters(ABC): | |||||
| pass | pass | ||||
| @define | @define | ||||
| class EASDeployParameters(ServiceParameters): | |||||
| class EASDeployParameters(ServiceProviderParameters): | |||||
| """Parameters for EAS Deployment. | """Parameters for EAS Deployment. | ||||
| Args: | Args: | ||||
| @@ -97,29 +111,10 @@ class EASDeployParameters(ServiceParameters): | |||||
| resource_group: Optional[str] = None | resource_group: Optional[str] = None | ||||
| vendor: str = field( | vendor: str = field( | ||||
| default=Vendor.EAS, validator=validators.in_([Vendor.EAS])) | default=Vendor.EAS, validator=validators.in_([Vendor.EAS])) | ||||
| """ | |||||
| def __init__(self, | |||||
| instance_name: str, | |||||
| access_key_id: str, | |||||
| access_key_secret: str, | |||||
| region = EASRegion.beijing, | |||||
| instance_type: str = EASCpuInstances.small, | |||||
| accelerator: str = Accelerator.CPU, | |||||
| resource_group: Optional[str] = None, | |||||
| scaling: Optional[str] = None): | |||||
| self.instance_name=instance_name | |||||
| self.access_key_id=self.access_key_id | |||||
| self.access_key_secret = access_key_secret | |||||
| self.region = region | |||||
| self.instance_type = instance_type | |||||
| self.accelerator = accelerator | |||||
| self.resource_group = resource_group | |||||
| self.scaling = scaling | |||||
| """ | |||||
| @define | @define | ||||
| class EASListParameters(ServiceParameters): | |||||
| class EASListParameters(ServiceProviderParameters): | |||||
| """EAS instance list parameters. | """EAS instance list parameters. | ||||
| Args: | Args: | ||||
| @@ -152,7 +147,7 @@ class DeployServiceParameters(object): | |||||
| model_id: str | model_id: str | ||||
| revision: str | revision: str | ||||
| resource: ServiceResourceConfig | resource: ServiceResourceConfig | ||||
| provider: ServiceParameters | |||||
| provider: ServiceProviderParameters | |||||
| class AttrsToQueryString(ABC): | class AttrsToQueryString(ABC): | ||||
| @@ -174,16 +169,173 @@ class AttrsToQueryString(ABC): | |||||
| @define | @define | ||||
| class ListServiceParameters(AttrsToQueryString): | class ListServiceParameters(AttrsToQueryString): | ||||
| provider: ServiceParameters | |||||
| provider: ServiceProviderParameters | |||||
| skip: int = 0 | skip: int = 0 | ||||
| limit: int = 100 | limit: int = 100 | ||||
| @define | @define | ||||
| class GetServiceParameters(AttrsToQueryString): | class GetServiceParameters(AttrsToQueryString): | ||||
| provider: ServiceParameters | |||||
| provider: ServiceProviderParameters | |||||
| @define | @define | ||||
| class DeleteServiceParameters(AttrsToQueryString): | class DeleteServiceParameters(AttrsToQueryString): | ||||
| provider: ServiceParameters | |||||
| provider: ServiceProviderParameters | |||||
| class ServiceDeployer(object): | |||||
| def __init__(self, endpoint=None): | |||||
| self.endpoint = endpoint if endpoint is not None else get_endpoint() | |||||
| self.cookies = ModelScopeConfig.get_cookies() | |||||
| if self.cookies is None: | |||||
| raise NotLoginException( | |||||
| 'Token does not exist, please login with HubApi first.') | |||||
| # deploy_model | |||||
| def create(self, model_id: str, revision: str, instance_name: str, | |||||
| resource: ServiceResourceConfig, | |||||
| provider: ServiceProviderParameters): | |||||
| """Deploy model to cloud, current we only support PAI EAS, this is an async API , | |||||
| and the deployment could take a while to finish remotely. Please check deploy instance | |||||
| status separately via checking the status. | |||||
| Args: | |||||
| model_id (str): The deployed model id | |||||
| revision (str): The model revision | |||||
| instance_name (str): The deployed model instance name. | |||||
| resource (ServiceResourceConfig): The service resource information. | |||||
| provider (ServiceProviderParameters): The service provider parameter | |||||
| Raises: | |||||
| NotLoginException: To use this api, you need login first. | |||||
| NotSupportError: Not supported platform. | |||||
| RequestError: The server return error. | |||||
| Returns: | |||||
| ServiceInstanceInfo: The information of the deployed service instance. | |||||
| """ | |||||
| if provider.vendor != Vendor.EAS: | |||||
| raise NotSupportError( | |||||
| 'Not support vendor: %s ,only support EAS current.' % | |||||
| (provider.vendor)) | |||||
| create_params = DeployServiceParameters( | |||||
| instance_name=instance_name, | |||||
| model_id=model_id, | |||||
| revision=revision, | |||||
| resource=resource, | |||||
| provider=provider) | |||||
| path = f'{self.endpoint}/api/v1/deployer/endpoint' | |||||
| body = attrs.asdict(create_params) | |||||
| r = requests.post( | |||||
| path, | |||||
| json=body, | |||||
| cookies=self.cookies, | |||||
| ) | |||||
| handle_http_response(r, logger, self.cookies, 'create_service') | |||||
| if r.status_code >= HTTPStatus.OK and r.status_code < HTTPStatus.MULTIPLE_CHOICES: | |||||
| if is_ok(r.json()): | |||||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||||
| return data | |||||
| else: | |||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||||
| else: | |||||
| r.raise_for_status() | |||||
| return None | |||||
| def get(self, instance_name: str, provider: ServiceProviderParameters): | |||||
| """Query the specified instance information. | |||||
| Args: | |||||
| instance_name (str): The deployed instance name. | |||||
| provider (ServiceProviderParameters): The cloud provider information, for eas | |||||
| need region(eg: ch-hangzhou), access_key_id and access_key_secret. | |||||
| Raises: | |||||
| NotLoginException: To use this api, you need login first. | |||||
| RequestError: The request is failed from server. | |||||
| Returns: | |||||
| Dict: The information of the requested service instance. | |||||
| """ | |||||
| params = GetServiceParameters(provider=provider) | |||||
| path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | |||||
| self.endpoint, instance_name, params.to_query_str()) | |||||
| r = requests.get(path, cookies=self.cookies) | |||||
| handle_http_response(r, logger, self.cookies, 'get_service') | |||||
| if r.status_code == HTTPStatus.OK: | |||||
| if is_ok(r.json()): | |||||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||||
| return data | |||||
| else: | |||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||||
| else: | |||||
| r.raise_for_status() | |||||
| return None | |||||
| def delete(self, instance_name: str, provider: ServiceProviderParameters): | |||||
| """Delete deployed model, this api send delete command and return, it will take | |||||
| some to delete, please check through the cloud console. | |||||
| Args: | |||||
| instance_name (str): The instance name you want to delete. | |||||
| provider (ServiceProviderParameters): The cloud provider information, for eas | |||||
| need region(eg: ch-hangzhou), access_key_id and access_key_secret. | |||||
| Raises: | |||||
| NotLoginException: To call this api, you need login first. | |||||
| RequestError: The request is failed. | |||||
| Returns: | |||||
| Dict: The deleted instance information. | |||||
| """ | |||||
| params = DeleteServiceParameters(provider=provider) | |||||
| path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | |||||
| self.endpoint, instance_name, params.to_query_str()) | |||||
| r = requests.delete(path, cookies=self.cookies) | |||||
| handle_http_response(r, logger, self.cookies, 'delete_service') | |||||
| if r.status_code == HTTPStatus.OK: | |||||
| if is_ok(r.json()): | |||||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||||
| return data | |||||
| else: | |||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||||
| else: | |||||
| r.raise_for_status() | |||||
| return None | |||||
| def list(self, | |||||
| provider: ServiceProviderParameters, | |||||
| skip: int = 0, | |||||
| limit: int = 100): | |||||
| """List deployed model instances. | |||||
| Args: | |||||
| provider (ServiceProviderParameters): The cloud service provider parameter, | |||||
| for eas, need access_key_id and access_key_secret. | |||||
| skip: start of the list, current not support. | |||||
| limit: maximum number of instances return, current not support | |||||
| Raises: | |||||
| NotLoginException: To use this api, you need login first. | |||||
| RequestError: The request is failed from server. | |||||
| Returns: | |||||
| List: List of instance information | |||||
| """ | |||||
| params = ListServiceParameters( | |||||
| provider=provider, skip=skip, limit=limit) | |||||
| path = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint, | |||||
| params.to_query_str()) | |||||
| r = requests.get(path, cookies=self.cookies) | |||||
| handle_http_response(r, logger, self.cookies, 'list_service_instances') | |||||
| if r.status_code == HTTPStatus.OK: | |||||
| if is_ok(r.json()): | |||||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||||
| return data | |||||
| else: | |||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||||
| else: | |||||
| r.raise_for_status() | |||||
| return None | |||||
| @@ -49,7 +49,7 @@ class FaceRecognitionPipeline(Pipeline): | |||||
| # face detect pipeline | # face detect pipeline | ||||
| det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' | det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' | ||||
| self.face_detection = pipeline( | self.face_detection = pipeline( | ||||
| Tasks.face_detection, model=det_model_id, model_revision='v2') | |||||
| Tasks.face_detection, model=det_model_id) | |||||
| def _choose_face(self, | def _choose_face(self, | ||||
| det_result, | det_result, | ||||
| @@ -17,6 +17,9 @@ from modelscope.preprocessors.space_T_cn.fields.database import Database | |||||
| from modelscope.preprocessors.space_T_cn.fields.struct import (Constant, | from modelscope.preprocessors.space_T_cn.fields.struct import (Constant, | ||||
| SQLQuery) | SQLQuery) | ||||
| from modelscope.utils.constant import ModelFile, Tasks | from modelscope.utils.constant import ModelFile, Tasks | ||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| __all__ = ['TableQuestionAnsweringPipeline'] | __all__ = ['TableQuestionAnsweringPipeline'] | ||||
| @@ -309,7 +312,8 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||||
| 'header_name': header_names, | 'header_name': header_names, | ||||
| 'rows': rows | 'rows': rows | ||||
| } | } | ||||
| except Exception: | |||||
| except Exception as e: | |||||
| logger.error(e) | |||||
| tabledata = {'header_id': [], 'header_name': [], 'rows': []} | tabledata = {'header_id': [], 'header_name': [], 'rows': []} | ||||
| else: | else: | ||||
| tabledata = {'header_id': [], 'header_name': [], 'rows': []} | tabledata = {'header_id': [], 'header_name': [], 'rows': []} | ||||
| @@ -17,7 +17,8 @@ class Database: | |||||
| self.tokenizer = tokenizer | self.tokenizer = tokenizer | ||||
| self.is_use_sqlite = is_use_sqlite | self.is_use_sqlite = is_use_sqlite | ||||
| if self.is_use_sqlite: | if self.is_use_sqlite: | ||||
| self.connection_obj = sqlite3.connect(':memory:') | |||||
| self.connection_obj = sqlite3.connect( | |||||
| ':memory:', check_same_thread=False) | |||||
| self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'} | self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'} | ||||
| self.tables = self.init_tables(table_file_path=table_file_path) | self.tables = self.init_tables(table_file_path=table_file_path) | ||||
| self.syn_dict = self.init_syn_dict( | self.syn_dict = self.init_syn_dict( | ||||
| @@ -28,8 +28,7 @@ class FaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| input_location = ['data/test/images/face_detection2.jpeg'] | input_location = ['data/test/images/face_detection2.jpeg'] | ||||
| dataset = MsDataset.load(input_location, target='image') | dataset = MsDataset.load(input_location, target='image') | ||||
| face_detection = pipeline( | |||||
| Tasks.face_detection, model=self.model_id, model_revision='v2') | |||||
| face_detection = pipeline(Tasks.face_detection, model=self.model_id) | |||||
| # note that for dataset output, the inference-output is a Generator that can be iterated. | # note that for dataset output, the inference-output is a Generator that can be iterated. | ||||
| result = face_detection(dataset) | result = face_detection(dataset) | ||||
| result = next(result) | result = next(result) | ||||
| @@ -37,8 +36,7 @@ class FaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_run_modelhub(self): | def test_run_modelhub(self): | ||||
| face_detection = pipeline( | |||||
| Tasks.face_detection, model=self.model_id, model_revision='v2') | |||||
| face_detection = pipeline(Tasks.face_detection, model=self.model_id) | |||||
| img_path = 'data/test/images/face_detection2.jpeg' | img_path = 'data/test/images/face_detection2.jpeg' | ||||
| result = face_detection(img_path) | result = face_detection(img_path) | ||||
| @@ -1,6 +1,7 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os | import os | ||||
| import unittest | import unittest | ||||
| from threading import Thread | |||||
| from typing import List | from typing import List | ||||
| import json | import json | ||||
| @@ -108,8 +109,6 @@ class TableQuestionAnswering(unittest.TestCase): | |||||
| self.task = Tasks.table_question_answering | self.task = Tasks.table_question_answering | ||||
| self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn' | self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn' | ||||
| model_id = 'damo/nlp_convai_text2sql_pretrain_cn' | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | ||||
| def test_run_by_direct_model_download(self): | def test_run_by_direct_model_download(self): | ||||
| cache_path = snapshot_download(self.model_id) | cache_path = snapshot_download(self.model_id) | ||||
| @@ -122,6 +121,27 @@ class TableQuestionAnswering(unittest.TestCase): | |||||
| ] | ] | ||||
| tableqa_tracking_and_print_results_with_history(pipelines) | tableqa_tracking_and_print_results_with_history(pipelines) | ||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_by_direct_model_download_with_multithreads(self): | |||||
| cache_path = snapshot_download(self.model_id) | |||||
| pl = pipeline(Tasks.table_question_answering, model=cache_path) | |||||
| def print_func(pl, i): | |||||
| result = pl({ | |||||
| 'question': '长江流域的小(2)型水库的库容总量是多少?', | |||||
| 'table_id': 'reservoir', | |||||
| 'history_sql': None | |||||
| }) | |||||
| print(i, json.dumps(result)) | |||||
| procs = [] | |||||
| for i in range(5): | |||||
| proc = Thread(target=print_func, args=(pl, i)) | |||||
| procs.append(proc) | |||||
| proc.start() | |||||
| for proc in procs: | |||||
| proc.join() | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
| def test_run_with_model_from_modelhub(self): | def test_run_with_model_from_modelhub(self): | ||||
| model = Model.from_pretrained(self.model_id) | model = Model.from_pretrained(self.model_id) | ||||
| @@ -28,7 +28,7 @@ def _setup(): | |||||
| val_root = val_dir + '/' + os.listdir(val_dir)[0] + '/' | val_root = val_dir + '/' + os.listdir(val_dir)[0] + '/' | ||||
| max_epochs = 1 # run epochs in unit test | max_epochs = 1 # run epochs in unit test | ||||
| cache_path = snapshot_download(model_id, revision='v2') | |||||
| cache_path = snapshot_download(model_id) | |||||
| tmp_dir = tempfile.TemporaryDirectory().name | tmp_dir = tempfile.TemporaryDirectory().name | ||||
| if not os.path.exists(tmp_dir): | if not os.path.exists(tmp_dir): | ||||
| @@ -34,14 +34,14 @@ class ImageDenoiseTrainerTest(unittest.TestCase): | |||||
| 'SIDD', | 'SIDD', | ||||
| namespace='huizheng', | namespace='huizheng', | ||||
| subset_name='default', | subset_name='default', | ||||
| split='validation', | |||||
| download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds | |||||
| split='test', | |||||
| download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds | |||||
| dataset_val = MsDataset.load( | dataset_val = MsDataset.load( | ||||
| 'SIDD', | 'SIDD', | ||||
| namespace='huizheng', | namespace='huizheng', | ||||
| subset_name='default', | subset_name='default', | ||||
| split='test', | split='test', | ||||
| download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds | |||||
| download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds | |||||
| self.dataset_train = SiddImageDenoisingDataset( | self.dataset_train = SiddImageDenoisingDataset( | ||||
| dataset_train, self.config.dataset, is_train=True) | dataset_train, self.config.dataset, is_train=True) | ||||
| self.dataset_val = SiddImageDenoisingDataset( | self.dataset_val = SiddImageDenoisingDataset( | ||||
| @@ -51,7 +51,7 @@ class ImageDenoiseTrainerTest(unittest.TestCase): | |||||
| shutil.rmtree(self.tmp_dir, ignore_errors=True) | shutil.rmtree(self.tmp_dir, ignore_errors=True) | ||||
| super().tearDown() | super().tearDown() | ||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_trainer(self): | def test_trainer(self): | ||||
| kwargs = dict( | kwargs = dict( | ||||
| model=self.model_id, | model=self.model_id, | ||||
| @@ -65,7 +65,7 @@ class ImageDenoiseTrainerTest(unittest.TestCase): | |||||
| for i in range(2): | for i in range(2): | ||||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | self.assertIn(f'epoch_{i+1}.pth', results_files) | ||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_trainer_with_model_and_args(self): | def test_trainer_with_model_and_args(self): | ||||
| model = NAFNetForImageDenoise.from_pretrained(self.cache_path) | model = NAFNetForImageDenoise.from_pretrained(self.cache_path) | ||||
| kwargs = dict( | kwargs = dict( | ||||