| @@ -12,7 +12,6 @@ from http.cookiejar import CookieJar | |||
| from os.path import expanduser | |||
| from typing import List, Optional, Tuple, Union | |||
| import attrs | |||
| import requests | |||
| from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||
| @@ -22,14 +21,9 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||
| API_RESPONSE_FIELD_USERNAME, | |||
| DEFAULT_CREDENTIALS_PATH, Licenses, | |||
| ModelVisibility) | |||
| from modelscope.hub.deploy import (DeleteServiceParameters, | |||
| DeployServiceParameters, | |||
| GetServiceParameters, ListServiceParameters, | |||
| ServiceParameters, ServiceResourceConfig, | |||
| Vendor) | |||
| from modelscope.hub.errors import (InvalidParameter, NotExistError, | |||
| NotLoginException, NotSupportError, | |||
| RequestError, datahub_raise_on_error, | |||
| NotLoginException, RequestError, | |||
| datahub_raise_on_error, | |||
| handle_http_post_error, | |||
| handle_http_response, is_ok, raise_on_error) | |||
| from modelscope.hub.git import GitCommandWrapper | |||
| @@ -312,169 +306,6 @@ class HubApi: | |||
| r.raise_for_status() | |||
| return None | |||
| def deploy_model(self, model_id: str, revision: str, instance_name: str, | |||
| resource: ServiceResourceConfig, | |||
| provider: ServiceParameters): | |||
| """Deploy model to cloud, current we only support PAI EAS, this is asynchronous | |||
| call , please check instance status through the console or query the instance status. | |||
| At the same time, this call may take a long time. | |||
| Args: | |||
| model_id (str): The deployed model id | |||
| revision (str): The model revision | |||
| instance_name (str): The deployed model instance name. | |||
| resource (DeployResource): The resource information. | |||
| provider (CreateParameter): The cloud service provider parameter | |||
| Raises: | |||
| NotLoginException: To use this api, you need login first. | |||
| NotSupportError: Not supported platform. | |||
| RequestError: The server return error. | |||
| Returns: | |||
| InstanceInfo: The instance information. | |||
| """ | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| if cookies is None: | |||
| raise NotLoginException( | |||
| 'Token does not exist, please login first.') | |||
| if provider.vendor != Vendor.EAS: | |||
| raise NotSupportError( | |||
| 'Not support vendor: %s ,only support EAS current.' % | |||
| (provider.vendor)) | |||
| create_params = DeployServiceParameters( | |||
| instance_name=instance_name, | |||
| model_id=model_id, | |||
| revision=revision, | |||
| resource=resource, | |||
| provider=provider) | |||
| path = f'{self.endpoint}/api/v1/deployer/endpoint' | |||
| body = attrs.asdict(create_params) | |||
| r = requests.post( | |||
| path, | |||
| json=body, | |||
| cookies=cookies, | |||
| ) | |||
| handle_http_response(r, logger, cookies, 'create_eas_instance') | |||
| if r.status_code >= HTTPStatus.OK and r.status_code < HTTPStatus.MULTIPLE_CHOICES: | |||
| if is_ok(r.json()): | |||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||
| return data | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| r.raise_for_status() | |||
| return None | |||
| def list_deployed_model_instances(self, | |||
| provider: ServiceParameters, | |||
| skip: int = 0, | |||
| limit: int = 100): | |||
| """List deployed model instances. | |||
| Args: | |||
| provider (ListServiceParameter): The cloud service provider parameter, | |||
| for eas, need access_key_id and access_key_secret. | |||
| skip: start of the list, current not support. | |||
| limit: maximum number of instances return, current not support | |||
| Raises: | |||
| NotLoginException: To use this api, you need login first. | |||
| RequestError: The request is failed from server. | |||
| Returns: | |||
| List: List of instance information | |||
| """ | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| if cookies is None: | |||
| raise NotLoginException( | |||
| 'Token does not exist, please login first.') | |||
| params = ListServiceParameters( | |||
| provider=provider, skip=skip, limit=limit) | |||
| path = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint, | |||
| params.to_query_str()) | |||
| r = requests.get(path, cookies=cookies) | |||
| handle_http_response(r, logger, cookies, 'list_deployed_model') | |||
| if r.status_code == HTTPStatus.OK: | |||
| if is_ok(r.json()): | |||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||
| return data | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| r.raise_for_status() | |||
| return None | |||
| def get_deployed_model_instance(self, instance_name: str, | |||
| provider: ServiceParameters): | |||
| """Query the specified instance information. | |||
| Args: | |||
| instance_name (str): The deployed instance name. | |||
| provider (GetParameter): The cloud provider information, for eas | |||
| need region(eg: ch-hangzhou), access_key_id and access_key_secret. | |||
| Raises: | |||
| NotLoginException: To use this api, you need login first. | |||
| RequestError: The request is failed from server. | |||
| Returns: | |||
| Dict: The request instance information | |||
| """ | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| if cookies is None: | |||
| raise NotLoginException( | |||
| 'Token does not exist, please login first.') | |||
| params = GetServiceParameters(provider=provider) | |||
| path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | |||
| self.endpoint, instance_name, params.to_query_str()) | |||
| r = requests.get(path, cookies=cookies) | |||
| handle_http_response(r, logger, cookies, 'get_deployed_model') | |||
| if r.status_code == HTTPStatus.OK: | |||
| if is_ok(r.json()): | |||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||
| return data | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| r.raise_for_status() | |||
| return None | |||
| def delete_deployed_model_instance(self, instance_name: str, | |||
| provider: ServiceParameters): | |||
| """Delete deployed model, this api send delete command and return, it will take | |||
| some to delete, please check through the cloud console. | |||
| Args: | |||
| instance_name (str): The instance name you want to delete. | |||
| provider (DeleteParameter): The cloud provider information, for eas | |||
| need region(eg: ch-hangzhou), access_key_id and access_key_secret. | |||
| Raises: | |||
| NotLoginException: To call this api, you need login first. | |||
| RequestError: The request is failed. | |||
| Returns: | |||
| Dict: The deleted instance information. | |||
| """ | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| if cookies is None: | |||
| raise NotLoginException( | |||
| 'Token does not exist, please login first.') | |||
| params = DeleteServiceParameters(provider=provider) | |||
| path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | |||
| self.endpoint, instance_name, params.to_query_str()) | |||
| r = requests.delete(path, cookies=cookies) | |||
| handle_http_response(r, logger, cookies, 'delete_deployed_model') | |||
| if r.status_code == HTTPStatus.OK: | |||
| if is_ok(r.json()): | |||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||
| return data | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| r.raise_for_status() | |||
| return None | |||
| def _check_cookie(self, | |||
| use_cookies: Union[bool, | |||
| CookieJar] = False) -> CookieJar: | |||
| @@ -1,11 +1,25 @@ | |||
| import urllib | |||
| from abc import ABC, abstractmethod | |||
| from typing import Optional, Union | |||
| from abc import ABC | |||
| from http import HTTPStatus | |||
| from typing import Optional | |||
| import attrs | |||
| import json | |||
| from attr import fields | |||
| import requests | |||
| from attrs import asdict, define, field, validators | |||
| from modelscope.hub.api import ModelScopeConfig | |||
| from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||
| API_RESPONSE_FIELD_MESSAGE) | |||
| from modelscope.hub.errors import (NotLoginException, NotSupportError, | |||
| RequestError, handle_http_response, is_ok) | |||
| from modelscope.hub.utils.utils import get_endpoint | |||
| from modelscope.utils.logger import get_logger | |||
| # yapf: enable | |||
| logger = get_logger() | |||
| class Accelerator(object): | |||
| CPU = 'cpu' | |||
| @@ -76,12 +90,12 @@ class ServiceResourceConfig(object): | |||
| @define | |||
| class ServiceParameters(ABC): | |||
| class ServiceProviderParameters(ABC): | |||
| pass | |||
| @define | |||
| class EASDeployParameters(ServiceParameters): | |||
| class EASDeployParameters(ServiceProviderParameters): | |||
| """Parameters for EAS Deployment. | |||
| Args: | |||
| @@ -97,29 +111,10 @@ class EASDeployParameters(ServiceParameters): | |||
| resource_group: Optional[str] = None | |||
| vendor: str = field( | |||
| default=Vendor.EAS, validator=validators.in_([Vendor.EAS])) | |||
| """ | |||
| def __init__(self, | |||
| instance_name: str, | |||
| access_key_id: str, | |||
| access_key_secret: str, | |||
| region = EASRegion.beijing, | |||
| instance_type: str = EASCpuInstances.small, | |||
| accelerator: str = Accelerator.CPU, | |||
| resource_group: Optional[str] = None, | |||
| scaling: Optional[str] = None): | |||
| self.instance_name=instance_name | |||
| self.access_key_id=self.access_key_id | |||
| self.access_key_secret = access_key_secret | |||
| self.region = region | |||
| self.instance_type = instance_type | |||
| self.accelerator = accelerator | |||
| self.resource_group = resource_group | |||
| self.scaling = scaling | |||
| """ | |||
| @define | |||
| class EASListParameters(ServiceParameters): | |||
| class EASListParameters(ServiceProviderParameters): | |||
| """EAS instance list parameters. | |||
| Args: | |||
| @@ -152,7 +147,7 @@ class DeployServiceParameters(object): | |||
| model_id: str | |||
| revision: str | |||
| resource: ServiceResourceConfig | |||
| provider: ServiceParameters | |||
| provider: ServiceProviderParameters | |||
| class AttrsToQueryString(ABC): | |||
| @@ -174,16 +169,173 @@ class AttrsToQueryString(ABC): | |||
| @define | |||
| class ListServiceParameters(AttrsToQueryString): | |||
| provider: ServiceParameters | |||
| provider: ServiceProviderParameters | |||
| skip: int = 0 | |||
| limit: int = 100 | |||
| @define | |||
| class GetServiceParameters(AttrsToQueryString): | |||
| provider: ServiceParameters | |||
| provider: ServiceProviderParameters | |||
| @define | |||
| class DeleteServiceParameters(AttrsToQueryString): | |||
| provider: ServiceParameters | |||
| provider: ServiceProviderParameters | |||
| class ServiceDeployer(object): | |||
| def __init__(self, endpoint=None): | |||
| self.endpoint = endpoint if endpoint is not None else get_endpoint() | |||
| self.cookies = ModelScopeConfig.get_cookies() | |||
| if self.cookies is None: | |||
| raise NotLoginException( | |||
| 'Token does not exist, please login with HubApi first.') | |||
| # deploy_model | |||
| def create(self, model_id: str, revision: str, instance_name: str, | |||
| resource: ServiceResourceConfig, | |||
| provider: ServiceProviderParameters): | |||
| """Deploy model to cloud, current we only support PAI EAS, this is an async API , | |||
| and the deployment could take a while to finish remotely. Please check deploy instance | |||
| status separately via checking the status. | |||
| Args: | |||
| model_id (str): The deployed model id | |||
| revision (str): The model revision | |||
| instance_name (str): The deployed model instance name. | |||
| resource (ServiceResourceConfig): The service resource information. | |||
| provider (ServiceProviderParameters): The service provider parameter | |||
| Raises: | |||
| NotLoginException: To use this api, you need login first. | |||
| NotSupportError: Not supported platform. | |||
| RequestError: The server return error. | |||
| Returns: | |||
| ServiceInstanceInfo: The information of the deployed service instance. | |||
| """ | |||
| if provider.vendor != Vendor.EAS: | |||
| raise NotSupportError( | |||
| 'Not support vendor: %s ,only support EAS current.' % | |||
| (provider.vendor)) | |||
| create_params = DeployServiceParameters( | |||
| instance_name=instance_name, | |||
| model_id=model_id, | |||
| revision=revision, | |||
| resource=resource, | |||
| provider=provider) | |||
| path = f'{self.endpoint}/api/v1/deployer/endpoint' | |||
| body = attrs.asdict(create_params) | |||
| r = requests.post( | |||
| path, | |||
| json=body, | |||
| cookies=self.cookies, | |||
| ) | |||
| handle_http_response(r, logger, self.cookies, 'create_service') | |||
| if r.status_code >= HTTPStatus.OK and r.status_code < HTTPStatus.MULTIPLE_CHOICES: | |||
| if is_ok(r.json()): | |||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||
| return data | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| r.raise_for_status() | |||
| return None | |||
| def get(self, instance_name: str, provider: ServiceProviderParameters): | |||
| """Query the specified instance information. | |||
| Args: | |||
| instance_name (str): The deployed instance name. | |||
| provider (ServiceProviderParameters): The cloud provider information, for eas | |||
| need region(eg: ch-hangzhou), access_key_id and access_key_secret. | |||
| Raises: | |||
| NotLoginException: To use this api, you need login first. | |||
| RequestError: The request is failed from server. | |||
| Returns: | |||
| Dict: The information of the requested service instance. | |||
| """ | |||
| params = GetServiceParameters(provider=provider) | |||
| path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | |||
| self.endpoint, instance_name, params.to_query_str()) | |||
| r = requests.get(path, cookies=self.cookies) | |||
| handle_http_response(r, logger, self.cookies, 'get_service') | |||
| if r.status_code == HTTPStatus.OK: | |||
| if is_ok(r.json()): | |||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||
| return data | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| r.raise_for_status() | |||
| return None | |||
| def delete(self, instance_name: str, provider: ServiceProviderParameters): | |||
| """Delete deployed model, this api send delete command and return, it will take | |||
| some to delete, please check through the cloud console. | |||
| Args: | |||
| instance_name (str): The instance name you want to delete. | |||
| provider (ServiceProviderParameters): The cloud provider information, for eas | |||
| need region(eg: ch-hangzhou), access_key_id and access_key_secret. | |||
| Raises: | |||
| NotLoginException: To call this api, you need login first. | |||
| RequestError: The request is failed. | |||
| Returns: | |||
| Dict: The deleted instance information. | |||
| """ | |||
| params = DeleteServiceParameters(provider=provider) | |||
| path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | |||
| self.endpoint, instance_name, params.to_query_str()) | |||
| r = requests.delete(path, cookies=self.cookies) | |||
| handle_http_response(r, logger, self.cookies, 'delete_service') | |||
| if r.status_code == HTTPStatus.OK: | |||
| if is_ok(r.json()): | |||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||
| return data | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| r.raise_for_status() | |||
| return None | |||
| def list(self, | |||
| provider: ServiceProviderParameters, | |||
| skip: int = 0, | |||
| limit: int = 100): | |||
| """List deployed model instances. | |||
| Args: | |||
| provider (ServiceProviderParameters): The cloud service provider parameter, | |||
| for eas, need access_key_id and access_key_secret. | |||
| skip: start of the list, current not support. | |||
| limit: maximum number of instances return, current not support | |||
| Raises: | |||
| NotLoginException: To use this api, you need login first. | |||
| RequestError: The request is failed from server. | |||
| Returns: | |||
| List: List of instance information | |||
| """ | |||
| params = ListServiceParameters( | |||
| provider=provider, skip=skip, limit=limit) | |||
| path = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint, | |||
| params.to_query_str()) | |||
| r = requests.get(path, cookies=self.cookies) | |||
| handle_http_response(r, logger, self.cookies, 'list_service_instances') | |||
| if r.status_code == HTTPStatus.OK: | |||
| if is_ok(r.json()): | |||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||
| return data | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| r.raise_for_status() | |||
| return None | |||
| @@ -49,7 +49,7 @@ class FaceRecognitionPipeline(Pipeline): | |||
| # face detect pipeline | |||
| det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps' | |||
| self.face_detection = pipeline( | |||
| Tasks.face_detection, model=det_model_id, model_revision='v2') | |||
| Tasks.face_detection, model=det_model_id) | |||
| def _choose_face(self, | |||
| det_result, | |||
| @@ -17,6 +17,9 @@ from modelscope.preprocessors.space_T_cn.fields.database import Database | |||
| from modelscope.preprocessors.space_T_cn.fields.struct import (Constant, | |||
| SQLQuery) | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| __all__ = ['TableQuestionAnsweringPipeline'] | |||
| @@ -309,7 +312,8 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||
| 'header_name': header_names, | |||
| 'rows': rows | |||
| } | |||
| except Exception: | |||
| except Exception as e: | |||
| logger.error(e) | |||
| tabledata = {'header_id': [], 'header_name': [], 'rows': []} | |||
| else: | |||
| tabledata = {'header_id': [], 'header_name': [], 'rows': []} | |||
| @@ -17,7 +17,8 @@ class Database: | |||
| self.tokenizer = tokenizer | |||
| self.is_use_sqlite = is_use_sqlite | |||
| if self.is_use_sqlite: | |||
| self.connection_obj = sqlite3.connect(':memory:') | |||
| self.connection_obj = sqlite3.connect( | |||
| ':memory:', check_same_thread=False) | |||
| self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'} | |||
| self.tables = self.init_tables(table_file_path=table_file_path) | |||
| self.syn_dict = self.init_syn_dict( | |||
| @@ -28,8 +28,7 @@ class FaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| input_location = ['data/test/images/face_detection2.jpeg'] | |||
| dataset = MsDataset.load(input_location, target='image') | |||
| face_detection = pipeline( | |||
| Tasks.face_detection, model=self.model_id, model_revision='v2') | |||
| face_detection = pipeline(Tasks.face_detection, model=self.model_id) | |||
| # note that for dataset output, the inference-output is a Generator that can be iterated. | |||
| result = face_detection(dataset) | |||
| result = next(result) | |||
| @@ -37,8 +36,7 @@ class FaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_modelhub(self): | |||
| face_detection = pipeline( | |||
| Tasks.face_detection, model=self.model_id, model_revision='v2') | |||
| face_detection = pipeline(Tasks.face_detection, model=self.model_id) | |||
| img_path = 'data/test/images/face_detection2.jpeg' | |||
| result = face_detection(img_path) | |||
| @@ -1,6 +1,7 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import unittest | |||
| from threading import Thread | |||
| from typing import List | |||
| import json | |||
| @@ -108,8 +109,6 @@ class TableQuestionAnswering(unittest.TestCase): | |||
| self.task = Tasks.table_question_answering | |||
| self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn' | |||
| model_id = 'damo/nlp_convai_text2sql_pretrain_cn' | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_by_direct_model_download(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| @@ -122,6 +121,27 @@ class TableQuestionAnswering(unittest.TestCase): | |||
| ] | |||
| tableqa_tracking_and_print_results_with_history(pipelines) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_by_direct_model_download_with_multithreads(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| pl = pipeline(Tasks.table_question_answering, model=cache_path) | |||
| def print_func(pl, i): | |||
| result = pl({ | |||
| 'question': '长江流域的小(2)型水库的库容总量是多少?', | |||
| 'table_id': 'reservoir', | |||
| 'history_sql': None | |||
| }) | |||
| print(i, json.dumps(result)) | |||
| procs = [] | |||
| for i in range(5): | |||
| proc = Thread(target=print_func, args=(pl, i)) | |||
| procs.append(proc) | |||
| proc.start() | |||
| for proc in procs: | |||
| proc.join() | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| @@ -28,7 +28,7 @@ def _setup(): | |||
| val_root = val_dir + '/' + os.listdir(val_dir)[0] + '/' | |||
| max_epochs = 1 # run epochs in unit test | |||
| cache_path = snapshot_download(model_id, revision='v2') | |||
| cache_path = snapshot_download(model_id) | |||
| tmp_dir = tempfile.TemporaryDirectory().name | |||
| if not os.path.exists(tmp_dir): | |||
| @@ -34,14 +34,14 @@ class ImageDenoiseTrainerTest(unittest.TestCase): | |||
| 'SIDD', | |||
| namespace='huizheng', | |||
| subset_name='default', | |||
| split='validation', | |||
| download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds | |||
| split='test', | |||
| download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds | |||
| dataset_val = MsDataset.load( | |||
| 'SIDD', | |||
| namespace='huizheng', | |||
| subset_name='default', | |||
| split='test', | |||
| download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds | |||
| download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds | |||
| self.dataset_train = SiddImageDenoisingDataset( | |||
| dataset_train, self.config.dataset, is_train=True) | |||
| self.dataset_val = SiddImageDenoisingDataset( | |||
| @@ -51,7 +51,7 @@ class ImageDenoiseTrainerTest(unittest.TestCase): | |||
| shutil.rmtree(self.tmp_dir, ignore_errors=True) | |||
| super().tearDown() | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_trainer(self): | |||
| kwargs = dict( | |||
| model=self.model_id, | |||
| @@ -65,7 +65,7 @@ class ImageDenoiseTrainerTest(unittest.TestCase): | |||
| for i in range(2): | |||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_trainer_with_model_and_args(self): | |||
| model = NAFNetForImageDenoise.from_pretrained(self.cache_path) | |||
| kwargs = dict( | |||