| @@ -12,6 +12,7 @@ from http.cookiejar import CookieJar | |||||
| from os.path import expanduser | from os.path import expanduser | ||||
| from typing import List, Optional, Tuple, Union | from typing import List, Optional, Tuple, Union | ||||
| import attrs | |||||
| import requests | import requests | ||||
| from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | ||||
| @@ -21,9 +22,14 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||||
| API_RESPONSE_FIELD_USERNAME, | API_RESPONSE_FIELD_USERNAME, | ||||
| DEFAULT_CREDENTIALS_PATH, Licenses, | DEFAULT_CREDENTIALS_PATH, Licenses, | ||||
| ModelVisibility) | ModelVisibility) | ||||
| from modelscope.hub.deploy import (DeleteServiceParameters, | |||||
| DeployServiceParameters, | |||||
| GetServiceParameters, ListServiceParameters, | |||||
| ServiceParameters, ServiceResourceConfig, | |||||
| Vendor) | |||||
| from modelscope.hub.errors import (InvalidParameter, NotExistError, | from modelscope.hub.errors import (InvalidParameter, NotExistError, | ||||
| NotLoginException, RequestError, | |||||
| datahub_raise_on_error, | |||||
| NotLoginException, NotSupportError, | |||||
| RequestError, datahub_raise_on_error, | |||||
| handle_http_post_error, | handle_http_post_error, | ||||
| handle_http_response, is_ok, raise_on_error) | handle_http_response, is_ok, raise_on_error) | ||||
| from modelscope.hub.git import GitCommandWrapper | from modelscope.hub.git import GitCommandWrapper | ||||
| @@ -306,6 +312,169 @@ class HubApi: | |||||
| r.raise_for_status() | r.raise_for_status() | ||||
| return None | return None | ||||
| def deploy_model(self, model_id: str, revision: str, instance_name: str, | |||||
| resource: ServiceResourceConfig, | |||||
| provider: ServiceParameters): | |||||
| """Deploy model to cloud, current we only support PAI EAS, this is asynchronous | |||||
| call , please check instance status through the console or query the instance status. | |||||
| At the same time, this call may take a long time. | |||||
| Args: | |||||
| model_id (str): The id of the model to deploy. |||||
| revision (str): The model revision. |||||
| instance_name (str): The name of the deployed service instance. |||||
| resource (ServiceResourceConfig): The resource configuration. |||||
| provider (ServiceParameters): The cloud service provider parameters. |||||
| Raises: | |||||
| NotLoginException: To use this api, you need to login first. |||||
| NotSupportError: The platform is not supported. |||||
| RequestError: The server returned an error. |||||
| Returns: | |||||
| Dict: The instance information. |||||
| """ | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| if cookies is None: | |||||
| raise NotLoginException( | |||||
| 'Token does not exist, please login first.') | |||||
| if provider.vendor != Vendor.EAS: | |||||
| raise NotSupportError( |||||
| 'Unsupported vendor: %s, currently only EAS is supported.' % |||||
| provider.vendor) |||||
| create_params = DeployServiceParameters( | |||||
| instance_name=instance_name, | |||||
| model_id=model_id, | |||||
| revision=revision, | |||||
| resource=resource, | |||||
| provider=provider) | |||||
| path = f'{self.endpoint}/api/v1/deployer/endpoint' | |||||
| body = attrs.asdict(create_params) | |||||
| r = requests.post( | |||||
| path, | |||||
| json=body, | |||||
| cookies=cookies, | |||||
| ) | |||||
| handle_http_response(r, logger, cookies, 'create_eas_instance') | |||||
| if r.status_code >= HTTPStatus.OK and r.status_code < HTTPStatus.MULTIPLE_CHOICES: | |||||
| if is_ok(r.json()): | |||||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||||
| return data | |||||
| else: | |||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||||
| else: | |||||
| r.raise_for_status() | |||||
| return None | |||||
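For context, a minimal usage sketch of the new `deploy_model` API. The token, access keys, model id, and instance name below are placeholders, and `HubApi.login` accepting an SDK token is an assumption; treat this as an illustration, not documentation.

```python
from modelscope.hub.api import HubApi
from modelscope.hub.deploy import (Accelerator, EASCpuInstanceType,
                                   EASDeployParameters, EASRegion,
                                   ServiceResourceConfig,
                                   ServiceScalingConfig)

api = HubApi()
api.login('<modelscope-sdk-token>')  # placeholder; caches the login cookies

# Machine type and replica range for the service.
resource = ServiceResourceConfig(
    instance_type=EASCpuInstanceType.small,
    scaling=ServiceScalingConfig(min_replica=1, max_replica=2),
    accelerator=Accelerator.CPU)

# EAS account credentials (placeholders).
provider = EASDeployParameters(
    region=EASRegion.hangzhou,
    access_key_id='<access-key-id>',
    access_key_secret='<access-key-secret>')

# Asynchronous: returns instance info while EAS provisions in the background.
info = api.deploy_model(
    model_id='<namespace>/<model-name>',  # placeholder model id
    revision='v1.0.0',
    instance_name='my-service',
    resource=resource,
    provider=provider)
```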
| def list_deployed_model_instances(self, | |||||
| provider: ServiceParameters, | |||||
| skip: int = 0, | |||||
| limit: int = 100): | |||||
| """List deployed model instances. | |||||
| Args: | |||||
| provider (ServiceParameters): The cloud service provider parameters; |||||
| for EAS, access_key_id and access_key_secret are required. |||||
| skip (int): Offset into the list; not currently supported. |||||
| limit (int): Maximum number of instances to return; not currently supported. |||||
| Raises: | |||||
| NotLoginException: To use this api, you need to login first. |||||
| RequestError: The server request failed. |||||
| Returns: |||||
| List: The list of instance information. |||||
| """ | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| if cookies is None: | |||||
| raise NotLoginException( | |||||
| 'Token does not exist, please login first.') | |||||
| params = ListServiceParameters( | |||||
| provider=provider, skip=skip, limit=limit) | |||||
| path = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint, | |||||
| params.to_query_str()) | |||||
| r = requests.get(path, cookies=cookies) | |||||
| handle_http_response(r, logger, cookies, 'list_deployed_model') | |||||
| if r.status_code == HTTPStatus.OK: | |||||
| if is_ok(r.json()): | |||||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||||
| return data | |||||
| else: | |||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||||
| else: | |||||
| r.raise_for_status() | |||||
| return None | |||||
| def get_deployed_model_instance(self, instance_name: str, | |||||
| provider: ServiceParameters): | |||||
| """Query the specified instance information. | |||||
| Args: | |||||
| instance_name (str): The deployed instance name. |||||
| provider (ServiceParameters): The cloud provider information; for EAS, |||||
| region (e.g. cn-hangzhou), access_key_id and access_key_secret are required. |||||
| Raises: | |||||
| NotLoginException: To use this api, you need to login first. |||||
| RequestError: The server request failed. |||||
| Returns: |||||
| Dict: The requested instance information. |||||
| """ | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| if cookies is None: | |||||
| raise NotLoginException( | |||||
| 'Token does not exist, please login first.') | |||||
| params = GetServiceParameters(provider=provider) | |||||
| path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | |||||
| self.endpoint, instance_name, params.to_query_str()) | |||||
| r = requests.get(path, cookies=cookies) | |||||
| handle_http_response(r, logger, cookies, 'get_deployed_model') | |||||
| if r.status_code == HTTPStatus.OK: | |||||
| if is_ok(r.json()): | |||||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||||
| return data | |||||
| else: | |||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||||
| else: | |||||
| r.raise_for_status() | |||||
| return None | |||||
| def delete_deployed_model_instance(self, instance_name: str, | |||||
| provider: ServiceParameters): | |||||
| """Delete deployed model, this api send delete command and return, it will take | |||||
| some to delete, please check through the cloud console. | |||||
| Args: | |||||
| instance_name (str): The instance name you want to delete. | |||||
| provider (DeleteParameter): The cloud provider information, for eas | |||||
| need region(eg: ch-hangzhou), access_key_id and access_key_secret. | |||||
| Raises: | |||||
| NotLoginException: To call this api, you need to login first. |||||
| RequestError: The request failed. |||||
| Returns: | |||||
| Dict: The deleted instance information. | |||||
| """ | |||||
| cookies = ModelScopeConfig.get_cookies() | |||||
| if cookies is None: | |||||
| raise NotLoginException( | |||||
| 'Token does not exist, please login first.') | |||||
| params = DeleteServiceParameters(provider=provider) | |||||
| path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | |||||
| self.endpoint, instance_name, params.to_query_str()) | |||||
| r = requests.delete(path, cookies=cookies) | |||||
| handle_http_response(r, logger, cookies, 'delete_deployed_model') | |||||
| if r.status_code == HTTPStatus.OK: | |||||
| if is_ok(r.json()): | |||||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||||
| return data | |||||
| else: | |||||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||||
| else: | |||||
| r.raise_for_status() | |||||
| return None | |||||
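Continuing the sketch above, listing, querying, and deleting share the same provider-parameter pattern (`api` and `EASRegion` are reused from the previous sketch; credentials remain placeholders):

```python
from modelscope.hub.deploy import EASListParameters

provider = EASListParameters(
    access_key_id='<access-key-id>',
    access_key_secret='<access-key-secret>',
    region=EASRegion.hangzhou)

instances = api.list_deployed_model_instances(provider=provider)
detail = api.get_deployed_model_instance('my-service', provider=provider)
api.delete_deployed_model_instance('my-service', provider=provider)
```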
| def _check_cookie(self, | def _check_cookie(self, | ||||
| use_cookies: Union[bool, | use_cookies: Union[bool, | ||||
| CookieJar] = False) -> CookieJar: | CookieJar] = False) -> CookieJar: | ||||
| @@ -0,0 +1,189 @@ | |||||
| import urllib.parse |||||
| from abc import ABC, abstractmethod |||||
| from typing import Optional, Union |||||
| import json |||||
| from attrs import asdict, define, field, validators |||||
| class Accelerator(object): | |||||
| CPU = 'cpu' | |||||
| GPU = 'gpu' | |||||
| class Vendor(object): | |||||
| EAS = 'eas' | |||||
| class EASRegion(object): | |||||
| beijing = 'cn-beijing' | |||||
| hangzhou = 'cn-hangzhou' | |||||
| class EASCpuInstanceType(object): | |||||
| """EAS Cpu Instance TYpe, ref(https://help.aliyun.com/document_detail/144261.html) | |||||
| """ | |||||
| tiny = 'ecs.c6.2xlarge' | |||||
| small = 'ecs.c6.4xlarge' | |||||
| medium = 'ecs.c6.6xlarge' | |||||
| large = 'ecs.c6.8xlarge' | |||||
| class EASGpuInstanceType(object): | |||||
| """EAS Cpu Instance TYpe, ref(https://help.aliyun.com/document_detail/144261.html) | |||||
| """ | |||||
| tiny = 'ecs.gn5-c28g1.7xlarge' | |||||
| small = 'ecs.gn5-c8g1.4xlarge' | |||||
| medium = 'ecs.gn6i-c24g1.12xlarge' | |||||
| large = 'ecs.gn6e-c12g1.3xlarge' | |||||
| def min_smaller_than_max(instance, attribute, value): | |||||
| if value > instance.max_replica: | |||||
| raise ValueError( | |||||
| "'min_replica' value: %s has to be smaller than 'max_replica' value: %s!" | |||||
| % (value, instance.max_replica)) | |||||
| @define | |||||
| class ServiceScalingConfig(object): | |||||
| """Resource scaling config | |||||
| Currently we ignore max_replica | |||||
| Args: | |||||
| max_replica: maximum replica | |||||
| min_replica: minimum replica | |||||
| """ | |||||
| max_replica: int = field(default=1, validator=validators.ge(1)) | |||||
| min_replica: int = field( | |||||
| default=1, validator=[validators.ge(1), min_smaller_than_max]) | |||||
| @define | |||||
| class ServiceResourceConfig(object): | |||||
| """Eas Resource request. | |||||
| Args: | |||||
| accelerator: the accelerator(cpu|gpu) | |||||
| instance_type: the instance type. | |||||
| scaling: The instance scaling config. | |||||
| """ | |||||
| instance_type: str | |||||
| scaling: ServiceScalingConfig | |||||
| accelerator: str = field( | |||||
| default=Accelerator.CPU, | |||||
| validator=validators.in_([Accelerator.CPU, Accelerator.GPU])) | |||||
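The attrs validators make invalid configs fail at construction time; a quick sketch of the expected behavior:

```python
from modelscope.hub.deploy import ServiceScalingConfig

ServiceScalingConfig(min_replica=1, max_replica=4)  # ok
ServiceScalingConfig(min_replica=0, max_replica=4)  # ValueError: must be >= 1
ServiceScalingConfig(min_replica=5, max_replica=4)  # ValueError: min > max
```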
| @define | |||||
| class ServiceParameters(ABC): | |||||
| pass | |||||
| @define | |||||
| class EASDeployParameters(ServiceParameters): | |||||
| """Parameters for EAS Deployment. | |||||
| Args: | |||||
| region: The EAS instance region (e.g. cn-hangzhou). |||||
| access_key_id: The EAS account access key id. |||||
| access_key_secret: The EAS account access key secret. |||||
| resource_group: The resource group to deploy to; currently the default group is used. |||||
| vendor: Must be 'eas'. |||||
| """ | |||||
| region: str | |||||
| access_key_id: str | |||||
| access_key_secret: str | |||||
| resource_group: Optional[str] = None | |||||
| vendor: str = field( | |||||
| default=Vendor.EAS, validator=validators.in_([Vendor.EAS])) | |||||
| """ | |||||
| def __init__(self, | |||||
| instance_name: str, | |||||
| access_key_id: str, | |||||
| access_key_secret: str, | |||||
| region = EASRegion.beijing, | |||||
| instance_type: str = EASCpuInstances.small, | |||||
| accelerator: str = Accelerator.CPU, | |||||
| resource_group: Optional[str] = None, | |||||
| scaling: Optional[str] = None): | |||||
| self.instance_name=instance_name | |||||
| self.access_key_id=self.access_key_id | |||||
| self.access_key_secret = access_key_secret | |||||
| self.region = region | |||||
| self.instance_type = instance_type | |||||
| self.accelerator = accelerator | |||||
| self.resource_group = resource_group | |||||
| self.scaling = scaling | |||||
| """ | |||||
| @define | |||||
| class EASListParameters(ServiceParameters): | |||||
| """EAS instance list parameters. | |||||
| Args: | |||||
| access_key_id: The EAS account access key id. |||||
| access_key_secret: The EAS account access key secret. |||||
| region: The EAS instance region (e.g. cn-hangzhou). |||||
| resource_group: The resource group to query; currently the default group is used. |||||
| vendor: Must be 'eas'. |||||
| """ | |||||
| access_key_id: str | |||||
| access_key_secret: str | |||||
| region: Optional[str] = None |||||
| resource_group: Optional[str] = None |||||
| vendor: str = field( | |||||
| default=Vendor.EAS, validator=validators.in_([Vendor.EAS])) | |||||
| @define | |||||
| class DeployServiceParameters(object): | |||||
| """Deploy service parameters | |||||
| Args: | |||||
| instance_name: The name of the service. |||||
| model_id: The ModelScope model id. |||||
| revision: The ModelScope model revision. |||||
| resource: The resource requirements. |||||
| provider: The cloud service provider parameters. |||||
| """ | |||||
| instance_name: str | |||||
| model_id: str | |||||
| revision: str | |||||
| resource: ServiceResourceConfig | |||||
| provider: ServiceParameters | |||||
| class AttrsToQueryString(ABC): |||||
| """Serialize the `provider` attribute of an attrs class into a |||||
| url-encoded `provider=` query parameter. |||||
| """ |||||
| def to_query_str(self): |||||
| self_dict = asdict( |||||
| self.provider, filter=lambda attr, value: value is not None) |||||
| json_str = json.dumps(self_dict) |||||
| safe_str = urllib.parse.quote_plus(json_str) |||||
| query_param = 'provider=%s' % safe_str |||||
| return query_param |||||
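To illustrate, `to_query_str` serializes only the `provider` attribute (`skip` and `limit` are not included) as url-encoded JSON; a sketch of the expected output:

```python
from modelscope.hub.deploy import EASListParameters, ListServiceParameters

provider = EASListParameters(access_key_id='id', access_key_secret='secret')
params = ListServiceParameters(provider=provider, skip=0, limit=100)
print(params.to_query_str())
# provider=%7B%22access_key_id%22%3A+%22id%22%2C+%22access_key_secret%22%3A
# +%22secret%22%2C+%22vendor%22%3A+%22eas%22%7D
```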
| @define | |||||
| class ListServiceParameters(AttrsToQueryString): | |||||
| provider: ServiceParameters | |||||
| skip: int = 0 | |||||
| limit: int = 100 | |||||
| @define | |||||
| class GetServiceParameters(AttrsToQueryString): | |||||
| provider: ServiceParameters | |||||
| @define | |||||
| class DeleteServiceParameters(AttrsToQueryString): | |||||
| provider: ServiceParameters | |||||
| @@ -9,6 +9,10 @@ from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | logger = get_logger() | ||||
| class NotSupportError(Exception): | |||||
| pass | |||||
| class NotExistError(Exception): | class NotExistError(Exception): | ||||
| pass | pass | ||||
| @@ -66,6 +70,7 @@ def handle_http_response(response, logger, cookies, model_id): | |||||
| logger.error( | logger.error( | ||||
| f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be \ | f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be \ | ||||
| private. Please login first.') | private. Please login first.') | ||||
| logger.error('Response details: %s' % response.content) | |||||
| raise error | raise error | ||||
| @@ -67,8 +67,9 @@ class Models(object): | |||||
| space_dst = 'space-dst' | space_dst = 'space-dst' | ||||
| space_intent = 'space-intent' | space_intent = 'space-intent' | ||||
| space_modeling = 'space-modeling' | space_modeling = 'space-modeling' | ||||
| star = 'star' | |||||
| star3 = 'star3' | |||||
| space_T_en = 'space-T-en' | |||||
| space_T_cn = 'space-T-cn' | |||||
| tcrf = 'transformer-crf' | tcrf = 'transformer-crf' | ||||
| transformer_softmax = 'transformer-softmax' | transformer_softmax = 'transformer-softmax' | ||||
| lcrf = 'lstm-crf' | lcrf = 'lstm-crf' | ||||
| @@ -16,6 +16,7 @@ from modelscope.models.builder import MODELS | |||||
| from modelscope.preprocessors import LoadImage | from modelscope.preprocessors import LoadImage | ||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import ModelFile, Tasks | from modelscope.utils.constant import ModelFile, Tasks | ||||
| from .utils import timestamp_format | |||||
| from .yolox.data.data_augment import ValTransform | from .yolox.data.data_augment import ValTransform | ||||
| from .yolox.exp import get_exp_by_name | from .yolox.exp import get_exp_by_name | ||||
| from .yolox.utils import postprocess | from .yolox.utils import postprocess | ||||
| @@ -99,14 +100,17 @@ class RealtimeVideoDetector(TorchModel): | |||||
| def inference_video(self, v_path): | def inference_video(self, v_path): | ||||
| outputs = [] | outputs = [] | ||||
| desc = 'Detecting video: {}'.format(v_path) | desc = 'Detecting video: {}'.format(v_path) | ||||
| for frame, result in tqdm( | |||||
| self.inference_video_iter(v_path), desc=desc): | |||||
| for frame_idx, (frame, result) in enumerate( | |||||
| tqdm(self.inference_video_iter(v_path), desc=desc)): | |||||
| result = result + (timestamp_format(seconds=frame_idx | |||||
| / self.fps), ) | |||||
| outputs.append(result) | outputs.append(result) | ||||
| return outputs | return outputs | ||||
| def inference_video_iter(self, v_path): | def inference_video_iter(self, v_path): | ||||
| capture = cv2.VideoCapture(v_path) | capture = cv2.VideoCapture(v_path) | ||||
| self.fps = capture.get(cv2.CAP_PROP_FPS) | |||||
| while capture.isOpened(): | while capture.isOpened(): | ||||
| ret, frame = capture.read() | ret, frame = capture.read() | ||||
| if not ret: | if not ret: | ||||
| @@ -0,0 +1,9 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| def timestamp_format(seconds): |||||
| """Format a duration in seconds as an HH:MM:SS.mmm timestamp string.""" |||||
| m, s = divmod(seconds, 60) | |||||
| h, m = divmod(m, 60) | |||||
| time = '%02d:%02d:%06.3f' % (h, m, s) | |||||
| return time | |||||
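A quick check of the formatting (values worked out by hand):

```python
print(timestamp_format(3723.5))  # 01:02:03.500
print(timestamp_format(59.999))  # 00:00:59.999
print(timestamp_format(3600))    # 01:00:00.000
```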
| @@ -24,8 +24,8 @@ import json | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| class Star3Config(object): | |||||
| """Configuration class to store the configuration of a `Star3Model`. | |||||
| class SpaceTCnConfig(object): | |||||
| """Configuration class to store the configuration of a `SpaceTCnModel`. | |||||
| """ | """ | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -40,10 +40,10 @@ class Star3Config(object): | |||||
| max_position_embeddings=512, | max_position_embeddings=512, | ||||
| type_vocab_size=2, | type_vocab_size=2, | ||||
| initializer_range=0.02): | initializer_range=0.02): | ||||
| """Constructs Star3Config. | |||||
| """Constructs SpaceTCnConfig. | |||||
| Args: | Args: | ||||
| vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `Star3Model`. | |||||
| vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `SpaceTCnModel`. |||||
| hidden_size: Size of the encoder layers and the pooler layer. | hidden_size: Size of the encoder layers and the pooler layer. | ||||
| num_hidden_layers: Number of hidden layers in the Transformer encoder. | num_hidden_layers: Number of hidden layers in the Transformer encoder. | ||||
| num_attention_heads: Number of attention heads for each attention layer in | num_attention_heads: Number of attention heads for each attention layer in | ||||
| @@ -59,7 +59,7 @@ class Star3Config(object): | |||||
| max_position_embeddings: The maximum sequence length that this model might | max_position_embeddings: The maximum sequence length that this model might | ||||
| ever be used with. Typically set this to something large just in case | ever be used with. Typically set this to something large just in case | ||||
| (e.g., 512 or 1024 or 2048). | (e.g., 512 or 1024 or 2048). | ||||
| type_vocab_size: The vocabulary size of the `token_type_ids` passed into `Star3Model`. | |||||
| type_vocab_size: The vocabulary size of the `token_type_ids` passed into `SpaceTCnModel`. |||||
| initializer_range: The stddev of the truncated_normal_initializer for | initializer_range: The stddev of the truncated_normal_initializer for | ||||
| initializing all weight matrices. | initializing all weight matrices. | ||||
| """ | """ | ||||
| @@ -89,15 +89,15 @@ class Star3Config(object): | |||||
| @classmethod | @classmethod | ||||
| def from_dict(cls, json_object): | def from_dict(cls, json_object): | ||||
| """Constructs a `Star3Config` from a Python dictionary of parameters.""" | |||||
| config = Star3Config(vocab_size_or_config_json_file=-1) | |||||
| """Constructs a `SpaceTCnConfig` from a Python dictionary of parameters.""" | |||||
| config = SpaceTCnConfig(vocab_size_or_config_json_file=-1) | |||||
| for key, value in json_object.items(): | for key, value in json_object.items(): | ||||
| config.__dict__[key] = value | config.__dict__[key] = value | ||||
| return config | return config | ||||
| @classmethod | @classmethod | ||||
| def from_json_file(cls, json_file): | def from_json_file(cls, json_file): | ||||
| """Constructs a `Star3Config` from a json file of parameters.""" | |||||
| """Constructs a `SpaceTCnConfig` from a json file of parameters.""" | |||||
| with open(json_file, 'r', encoding='utf-8') as reader: | with open(json_file, 'r', encoding='utf-8') as reader: | ||||
| text = reader.read() | text = reader.read() | ||||
| return cls.from_dict(json.loads(text)) | return cls.from_dict(json.loads(text)) | ||||
| @@ -27,7 +27,8 @@ import numpy as np | |||||
| import torch | import torch | ||||
| from torch import nn | from torch import nn | ||||
| from modelscope.models.nlp.star3.configuration_star3 import Star3Config | |||||
| from modelscope.models.nlp.space_T_cn.configuration_space_T_cn import \ | |||||
| SpaceTCnConfig | |||||
| from modelscope.utils.constant import ModelFile | from modelscope.utils.constant import ModelFile | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| @@ -609,9 +610,9 @@ class PreTrainedBertModel(nn.Module): | |||||
| def __init__(self, config, *inputs, **kwargs): | def __init__(self, config, *inputs, **kwargs): | ||||
| super(PreTrainedBertModel, self).__init__() | super(PreTrainedBertModel, self).__init__() | ||||
| if not isinstance(config, Star3Config): | |||||
| if not isinstance(config, SpaceTCnConfig): | |||||
| raise ValueError( | raise ValueError( | ||||
| 'Parameter config in `{}(config)` should be an instance of class `Star3Config`. ' | |||||
| 'Parameter config in `{}(config)` should be an instance of class `SpaceTCnConfig`. ' | |||||
| 'To create a model from a Google pretrained model use ' | 'To create a model from a Google pretrained model use ' | ||||
| '`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`'.format( | '`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`'.format( | ||||
| self.__class__.__name__, self.__class__.__name__)) | self.__class__.__name__, self.__class__.__name__)) | ||||
| @@ -676,7 +677,7 @@ class PreTrainedBertModel(nn.Module): | |||||
| serialization_dir = tempdir | serialization_dir = tempdir | ||||
| # Load config | # Load config | ||||
| config_file = os.path.join(serialization_dir, CONFIG_NAME) | config_file = os.path.join(serialization_dir, CONFIG_NAME) | ||||
| config = Star3Config.from_json_file(config_file) | |||||
| config = SpaceTCnConfig.from_json_file(config_file) | |||||
| logger.info('Model config {}'.format(config)) | logger.info('Model config {}'.format(config)) | ||||
| # Instantiate model. | # Instantiate model. | ||||
| model = cls(config, *inputs, **kwargs) | model = cls(config, *inputs, **kwargs) | ||||
| @@ -742,11 +743,11 @@ class PreTrainedBertModel(nn.Module): | |||||
| return model | return model | ||||
| class Star3Model(PreTrainedBertModel): | |||||
| """Star3Model model ("Bidirectional Embedding Representations from a Transformer pretrained on STAR3.0"). | |||||
| class SpaceTCnModel(PreTrainedBertModel): | |||||
| """SpaceTCnModel model ("Bidirectional Embedding Representations from a Transformer pretrained on STAR-T-CN"). | |||||
| Params: | Params: | ||||
| config: a Star3Config class instance with the configuration to build a new model | |||||
| config: a SpaceTCnConfig class instance with the configuration to build a new model | |||||
| Inputs: | Inputs: | ||||
| `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] | ||||
| @@ -780,16 +781,16 @@ class Star3Model(PreTrainedBertModel): | |||||
| input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | ||||
| token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) | ||||
| config = modeling.Star3Config(vocab_size_or_config_json_file=32000, hidden_size=768, | |||||
| config = modeling.SpaceTCnConfig(vocab_size_or_config_json_file=32000, hidden_size=768, | |||||
| num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | ||||
| model = modeling.Star3Model(config=config) | |||||
| model = modeling.SpaceTCnModel(config=config) | |||||
| all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) | ||||
| ``` | ``` | ||||
| """ | """ | ||||
| def __init__(self, config, schema_link_module='none'): | def __init__(self, config, schema_link_module='none'): | ||||
| super(Star3Model, self).__init__(config) | |||||
| super(SpaceTCnModel, self).__init__(config) | |||||
| self.embeddings = BertEmbeddings(config) | self.embeddings = BertEmbeddings(config) | ||||
| self.encoder = BertEncoder( | self.encoder = BertEncoder( | ||||
| config, schema_link_module=schema_link_module) | config, schema_link_module=schema_link_module) | ||||
| @@ -20,7 +20,7 @@ __all__ = ['StarForTextToSql'] | |||||
| @MODELS.register_module( | @MODELS.register_module( | ||||
| Tasks.conversational_text_to_sql, module_name=Models.star) | |||||
| Tasks.table_question_answering, module_name=Models.space_T_en) | |||||
| class StarForTextToSql(Model): | class StarForTextToSql(Model): | ||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| @@ -3,27 +3,25 @@ | |||||
| import os | import os | ||||
| from typing import Dict | from typing import Dict | ||||
| import json | |||||
| import numpy | import numpy | ||||
| import torch | import torch | ||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||
| import tqdm | |||||
| from transformers import BertTokenizer | from transformers import BertTokenizer | ||||
| from modelscope.metainfo import Models | from modelscope.metainfo import Models | ||||
| from modelscope.models.base import Model, Tensor | from modelscope.models.base import Model, Tensor | ||||
| from modelscope.models.builder import MODELS | from modelscope.models.builder import MODELS | ||||
| from modelscope.models.nlp.star3.configuration_star3 import Star3Config | |||||
| from modelscope.models.nlp.star3.modeling_star3 import Seq2SQL, Star3Model | |||||
| from modelscope.preprocessors.star3.fields.struct import Constant | |||||
| from modelscope.preprocessors.space_T_cn.fields.struct import Constant | |||||
| from modelscope.utils.constant import ModelFile, Tasks | from modelscope.utils.constant import ModelFile, Tasks | ||||
| from modelscope.utils.device import verify_device | from modelscope.utils.device import verify_device | ||||
| from .space_T_cn.configuration_space_T_cn import SpaceTCnConfig | |||||
| from .space_T_cn.modeling_space_T_cn import Seq2SQL, SpaceTCnModel | |||||
| __all__ = ['TableQuestionAnswering'] | __all__ = ['TableQuestionAnswering'] | ||||
| @MODELS.register_module( | @MODELS.register_module( | ||||
| Tasks.table_question_answering, module_name=Models.star3) | |||||
| Tasks.table_question_answering, module_name=Models.space_T_cn) | |||||
| class TableQuestionAnswering(Model): | class TableQuestionAnswering(Model): | ||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| @@ -43,9 +41,9 @@ class TableQuestionAnswering(Model): | |||||
| os.path.join(self.model_dir, ModelFile.TORCH_MODEL_BIN_FILE), | os.path.join(self.model_dir, ModelFile.TORCH_MODEL_BIN_FILE), | ||||
| map_location='cpu') | map_location='cpu') | ||||
| self.backbone_config = Star3Config.from_json_file( | |||||
| self.backbone_config = SpaceTCnConfig.from_json_file( | |||||
| os.path.join(self.model_dir, ModelFile.CONFIGURATION)) | os.path.join(self.model_dir, ModelFile.CONFIGURATION)) | ||||
| self.backbone_model = Star3Model( | |||||
| self.backbone_model = SpaceTCnModel( | |||||
| config=self.backbone_config, schema_link_module='rat') | config=self.backbone_config, schema_link_module='rat') | ||||
| self.backbone_model.load_state_dict(state_dict['backbone_model']) | self.backbone_model.load_state_dict(state_dict['backbone_model']) | ||||
| @@ -606,21 +606,12 @@ TASK_OUTPUTS = { | |||||
| # } | # } | ||||
| Tasks.task_oriented_conversation: [OutputKeys.OUTPUT], | Tasks.task_oriented_conversation: [OutputKeys.OUTPUT], | ||||
| # conversational text-to-sql result for single sample | |||||
| # { | |||||
| # "text": "SELECT shop.Name FROM shop." | |||||
| # } | |||||
| Tasks.conversational_text_to_sql: [OutputKeys.TEXT], | |||||
| # table-question-answering result for single sample | # table-question-answering result for single sample | ||||
| # { | # { | ||||
| # "sql": "SELECT shop.Name FROM shop." | # "sql": "SELECT shop.Name FROM shop." | ||||
| # "sql_history": {sel: 0, agg: 0, conds: [[0, 0, 'val']]} | # "sql_history": {sel: 0, agg: 0, conds: [[0, 0, 'val']]} | ||||
| # } | # } | ||||
| Tasks.table_question_answering: [ | |||||
| OutputKeys.SQL_STRING, OutputKeys.SQL_QUERY, OutputKeys.HISTORY, | |||||
| OutputKeys.QUERT_RESULT | |||||
| ], | |||||
| Tasks.table_question_answering: [OutputKeys.OUTPUT], | |||||
| # ============ audio tasks =================== | # ============ audio tasks =================== | ||||
| # asr result for single sample | # asr result for single sample | ||||
| @@ -69,9 +69,6 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| 'damo/nlp_space_dialog-modeling'), | 'damo/nlp_space_dialog-modeling'), | ||||
| Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking, | Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking, | ||||
| 'damo/nlp_space_dialog-state-tracking'), | 'damo/nlp_space_dialog-state-tracking'), | ||||
| Tasks.conversational_text_to_sql: | |||||
| (Pipelines.conversational_text_to_sql, | |||||
| 'damo/nlp_star_conversational-text-to-sql'), | |||||
| Tasks.table_question_answering: | Tasks.table_question_answering: | ||||
| (Pipelines.table_question_answering_pipeline, | (Pipelines.table_question_answering_pipeline, | ||||
| 'damo/nlp-convai-text2sql-pretrain-cn'), | 'damo/nlp-convai-text2sql-pretrain-cn'), | ||||
| @@ -113,9 +113,8 @@ class AnimalRecognitionPipeline(Pipeline): | |||||
| label_mapping = f.readlines() | label_mapping = f.readlines() | ||||
| score = torch.max(inputs['outputs']) | score = torch.max(inputs['outputs']) | ||||
| inputs = { | inputs = { | ||||
| OutputKeys.SCORES: | |||||
| score.item(), | |||||
| OutputKeys.SCORES: [score.item()], | |||||
| OutputKeys.LABELS: | OutputKeys.LABELS: | ||||
| label_mapping[inputs['outputs'].argmax()].split('\t')[1] | |||||
| [label_mapping[inputs['outputs'].argmax()].split('\t')[1]] | |||||
| } | } | ||||
| return inputs | return inputs | ||||
| @@ -114,9 +114,8 @@ class GeneralRecognitionPipeline(Pipeline): | |||||
| label_mapping = f.readlines() | label_mapping = f.readlines() | ||||
| score = torch.max(inputs['outputs']) | score = torch.max(inputs['outputs']) | ||||
| inputs = { | inputs = { | ||||
| OutputKeys.SCORES: | |||||
| score.item(), | |||||
| OutputKeys.SCORES: [score.item()], | |||||
| OutputKeys.LABELS: | OutputKeys.LABELS: | ||||
| label_mapping[inputs['outputs'].argmax()].split('\t')[1] | |||||
| [label_mapping[inputs['outputs'].argmax()].split('\t')[1]] | |||||
| } | } | ||||
| return inputs | return inputs | ||||
| @@ -45,15 +45,17 @@ class RealtimeVideoObjectDetectionPipeline(Pipeline): | |||||
| **kwargs) -> str: | **kwargs) -> str: | ||||
| forward_output = input['forward_output'] | forward_output = input['forward_output'] | ||||
| scores, boxes, labels = [], [], [] | |||||
| scores, boxes, labels, timestamps = [], [], [], [] | |||||
| for result in forward_output: | for result in forward_output: | ||||
| box, score, label = result | |||||
| box, score, label, timestamp = result | |||||
| scores.append(score) | scores.append(score) | ||||
| boxes.append(box) | boxes.append(box) | ||||
| labels.append(label) | labels.append(label) | ||||
| timestamps.append(timestamp) | |||||
| return { | return { | ||||
| OutputKeys.BOXES: boxes, | OutputKeys.BOXES: boxes, | ||||
| OutputKeys.SCORES: scores, | OutputKeys.SCORES: scores, | ||||
| OutputKeys.LABELS: labels, | OutputKeys.LABELS: labels, | ||||
| OutputKeys.TIMESTAMPS: timestamps, | |||||
| } | } | ||||
| @@ -19,7 +19,7 @@ __all__ = ['ConversationalTextToSqlPipeline'] | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.conversational_text_to_sql, | |||||
| Tasks.table_question_answering, | |||||
| module_name=Pipelines.conversational_text_to_sql) | module_name=Pipelines.conversational_text_to_sql) | ||||
| class ConversationalTextToSqlPipeline(Pipeline): | class ConversationalTextToSqlPipeline(Pipeline): | ||||
| @@ -62,7 +62,7 @@ class ConversationalTextToSqlPipeline(Pipeline): | |||||
| Dict[str, str]: the prediction results | Dict[str, str]: the prediction results | ||||
| """ | """ | ||||
| sql = Example.evaluator.obtain_sql(inputs['predict'][0], inputs['db']) | sql = Example.evaluator.obtain_sql(inputs['predict'][0], inputs['db']) | ||||
| result = {OutputKeys.TEXT: sql} | |||||
| result = {OutputKeys.OUTPUT: {OutputKeys.TEXT: sql}} | |||||
| return result | return result | ||||
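With the result nested under `OutputKeys.OUTPUT`, callers now unpack one extra level, matching the test update later in this PR:

```python
# before: sql = results[OutputKeys.TEXT]
sql = results[OutputKeys.OUTPUT][OutputKeys.TEXT]
```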
| def _collate_fn(self, data): | def _collate_fn(self, data): | ||||
| @@ -13,8 +13,9 @@ from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Pipeline | from modelscope.pipelines.base import Pipeline | ||||
| from modelscope.pipelines.builder import PIPELINES | from modelscope.pipelines.builder import PIPELINES | ||||
| from modelscope.preprocessors import TableQuestionAnsweringPreprocessor | from modelscope.preprocessors import TableQuestionAnsweringPreprocessor | ||||
| from modelscope.preprocessors.star3.fields.database import Database | |||||
| from modelscope.preprocessors.star3.fields.struct import Constant, SQLQuery | |||||
| from modelscope.preprocessors.space_T_cn.fields.database import Database | |||||
| from modelscope.preprocessors.space_T_cn.fields.struct import (Constant, | |||||
| SQLQuery) | |||||
| from modelscope.utils.constant import ModelFile, Tasks | from modelscope.utils.constant import ModelFile, Tasks | ||||
| __all__ = ['TableQuestionAnsweringPipeline'] | __all__ = ['TableQuestionAnsweringPipeline'] | ||||
| @@ -320,7 +321,7 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||||
| OutputKeys.QUERT_RESULT: tabledata, | OutputKeys.QUERT_RESULT: tabledata, | ||||
| } | } | ||||
| return output | |||||
| return {OutputKeys.OUTPUT: output} | |||||
| def _collate_fn(self, data): | def _collate_fn(self, data): | ||||
| return data | return data | ||||
| @@ -40,7 +40,7 @@ if TYPE_CHECKING: | |||||
| DialogStateTrackingPreprocessor) | DialogStateTrackingPreprocessor) | ||||
| from .video import ReadVideoData, MovieSceneSegmentationPreprocessor | from .video import ReadVideoData, MovieSceneSegmentationPreprocessor | ||||
| from .star import ConversationalTextToSqlPreprocessor | from .star import ConversationalTextToSqlPreprocessor | ||||
| from .star3 import TableQuestionAnsweringPreprocessor | |||||
| from .space_T_cn import TableQuestionAnsweringPreprocessor | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| @@ -81,7 +81,7 @@ else: | |||||
| 'DialogStateTrackingPreprocessor', 'InputFeatures' | 'DialogStateTrackingPreprocessor', 'InputFeatures' | ||||
| ], | ], | ||||
| 'star': ['ConversationalTextToSqlPreprocessor'], | 'star': ['ConversationalTextToSqlPreprocessor'], | ||||
| 'star3': ['TableQuestionAnsweringPreprocessor'], | |||||
| 'space_T_cn': ['TableQuestionAnsweringPreprocessor'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -4,7 +4,7 @@ import sqlite3 | |||||
| import json | import json | ||||
| import tqdm | import tqdm | ||||
| from modelscope.preprocessors.star3.fields.struct import Trie | |||||
| from modelscope.preprocessors.space_T_cn.fields.struct import Trie | |||||
| class Database: | class Database: | ||||
| @@ -1,7 +1,7 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import re | import re | ||||
| from modelscope.preprocessors.star3.fields.struct import TypeInfo | |||||
| from modelscope.preprocessors.space_T_cn.fields.struct import TypeInfo | |||||
| class SchemaLinker: | class SchemaLinker: | ||||
| @@ -8,8 +8,8 @@ from transformers import BertTokenizer | |||||
| from modelscope.metainfo import Preprocessors | from modelscope.metainfo import Preprocessors | ||||
| from modelscope.preprocessors.base import Preprocessor | from modelscope.preprocessors.base import Preprocessor | ||||
| from modelscope.preprocessors.builder import PREPROCESSORS | from modelscope.preprocessors.builder import PREPROCESSORS | ||||
| from modelscope.preprocessors.star3.fields.database import Database | |||||
| from modelscope.preprocessors.star3.fields.schema_link import SchemaLinker | |||||
| from modelscope.preprocessors.space_T_cn.fields.database import Database | |||||
| from modelscope.preprocessors.space_T_cn.fields.schema_link import SchemaLinker | |||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import Fields, ModelFile | from modelscope.utils.constant import Fields, ModelFile | ||||
| from modelscope.utils.type_assert import type_assert | from modelscope.utils.type_assert import type_assert | ||||
| @@ -123,7 +123,6 @@ class NLPTasks(object): | |||||
| backbone = 'backbone' | backbone = 'backbone' | ||||
| text_error_correction = 'text-error-correction' | text_error_correction = 'text-error-correction' | ||||
| faq_question_answering = 'faq-question-answering' | faq_question_answering = 'faq-question-answering' | ||||
| conversational_text_to_sql = 'conversational-text-to-sql' | |||||
| information_extraction = 'information-extraction' | information_extraction = 'information-extraction' | ||||
| document_segmentation = 'document-segmentation' | document_segmentation = 'document-segmentation' | ||||
| feature_extraction = 'feature-extraction' | feature_extraction = 'feature-extraction' | ||||
| @@ -20,7 +20,7 @@ def text2sql_tracking_and_print_results( | |||||
| results = p(case) | results = p(case) | ||||
| print({'question': item}) | print({'question': item}) | ||||
| print(results) | print(results) | ||||
| last_sql = results['text'] | |||||
| last_sql = results[OutputKeys.OUTPUT][OutputKeys.TEXT] | |||||
| history.append(item) | history.append(item) | ||||
| @@ -1,4 +1,5 @@ | |||||
| addict | addict | ||||
| attrs | |||||
| datasets | datasets | ||||
| easydict | easydict | ||||
| einops | einops | ||||
| @@ -16,7 +16,7 @@ from modelscope.utils.test_utils import test_level | |||||
| class ConversationalTextToSql(unittest.TestCase, DemoCompatibilityCheck): | class ConversationalTextToSql(unittest.TestCase, DemoCompatibilityCheck): | ||||
| def setUp(self) -> None: | def setUp(self) -> None: | ||||
| self.task = Tasks.conversational_text_to_sql | |||||
| self.task = Tasks.table_question_answering | |||||
| self.model_id = 'damo/nlp_star_conversational-text-to-sql' | self.model_id = 'damo/nlp_star_conversational-text-to-sql' | ||||
| model_id = 'damo/nlp_star_conversational-text-to-sql' | model_id = 'damo/nlp_star_conversational-text-to-sql' | ||||
| @@ -66,11 +66,6 @@ class ConversationalTextToSql(unittest.TestCase, DemoCompatibilityCheck): | |||||
| pipelines = [pipeline(task=self.task, model=self.model_id)] | pipelines = [pipeline(task=self.task, model=self.model_id)] | ||||
| text2sql_tracking_and_print_results(self.test_case, pipelines) | text2sql_tracking_and_print_results(self.test_case, pipelines) | ||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_with_default_model(self): | |||||
| pipelines = [pipeline(task=self.task)] | |||||
| text2sql_tracking_and_print_results(self.test_case, pipelines) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_demo_compatibility(self): | def test_demo_compatibility(self): | ||||
| self.compatibility_check() | self.compatibility_check() | ||||
| @@ -12,7 +12,7 @@ from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines import pipeline | from modelscope.pipelines import pipeline | ||||
| from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline | from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline | ||||
| from modelscope.preprocessors import TableQuestionAnsweringPreprocessor | from modelscope.preprocessors import TableQuestionAnsweringPreprocessor | ||||
| from modelscope.preprocessors.star3.fields.database import Database | |||||
| from modelscope.preprocessors.space_T_cn.fields.database import Database | |||||
| from modelscope.utils.constant import ModelFile, Tasks | from modelscope.utils.constant import ModelFile, Tasks | ||||
| from modelscope.utils.test_utils import test_level | from modelscope.utils.test_utils import test_level | ||||
| @@ -38,7 +38,7 @@ def tableqa_tracking_and_print_results_with_history( | |||||
| output_dict = p({ | output_dict = p({ | ||||
| 'question': question, | 'question': question, | ||||
| 'history_sql': historical_queries | 'history_sql': historical_queries | ||||
| }) | |||||
| })[OutputKeys.OUTPUT] | |||||
| print('question', question) | print('question', question) | ||||
| print('sql text:', output_dict[OutputKeys.SQL_STRING]) | print('sql text:', output_dict[OutputKeys.SQL_STRING]) | ||||
| print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | ||||
| @@ -61,7 +61,7 @@ def tableqa_tracking_and_print_results_without_history( | |||||
| } | } | ||||
| for p in pipelines: | for p in pipelines: | ||||
| for question in test_case['utterance']: | for question in test_case['utterance']: | ||||
| output_dict = p({'question': question}) | |||||
| output_dict = p({'question': question})[OutputKeys.OUTPUT] | |||||
| print('question', question) | print('question', question) | ||||
| print('sql text:', output_dict[OutputKeys.SQL_STRING]) | print('sql text:', output_dict[OutputKeys.SQL_STRING]) | ||||
| print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | ||||
| @@ -92,7 +92,7 @@ def tableqa_tracking_and_print_results_with_tableid( | |||||
| 'question': question, | 'question': question, | ||||
| 'table_id': table_id, | 'table_id': table_id, | ||||
| 'history_sql': historical_queries | 'history_sql': historical_queries | ||||
| }) | |||||
| })[OutputKeys.OUTPUT] | |||||
| print('question', question) | print('question', question) | ||||
| print('sql text:', output_dict[OutputKeys.SQL_STRING]) | print('sql text:', output_dict[OutputKeys.SQL_STRING]) | ||||
| print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | ||||
| @@ -147,11 +147,6 @@ class TableQuestionAnswering(unittest.TestCase): | |||||
| ] | ] | ||||
| tableqa_tracking_and_print_results_with_tableid(pipelines) | tableqa_tracking_and_print_results_with_tableid(pipelines) | ||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_model_from_task(self): | |||||
| pipelines = [pipeline(Tasks.table_question_answering, self.model_id)] | |||||
| tableqa_tracking_and_print_results_with_history(pipelines) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | ||||
| def test_run_with_model_from_modelhub_with_other_classes(self): | def test_run_with_model_from_modelhub_with_other_classes(self): | ||||
| model = Model.from_pretrained(self.model_id) | model = Model.from_pretrained(self.model_id) | ||||