| @@ -12,6 +12,7 @@ from http.cookiejar import CookieJar | |||
| from os.path import expanduser | |||
| from typing import List, Optional, Tuple, Union | |||
| import attrs | |||
| import requests | |||
| from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||
| @@ -21,9 +22,14 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, | |||
| API_RESPONSE_FIELD_USERNAME, | |||
| DEFAULT_CREDENTIALS_PATH, Licenses, | |||
| ModelVisibility) | |||
| from modelscope.hub.deploy import (DeleteServiceParameters, | |||
| DeployServiceParameters, | |||
| GetServiceParameters, ListServiceParameters, | |||
| ServiceParameters, ServiceResourceConfig, | |||
| Vendor) | |||
| from modelscope.hub.errors import (InvalidParameter, NotExistError, | |||
| NotLoginException, RequestError, | |||
| datahub_raise_on_error, | |||
| NotLoginException, NotSupportError, | |||
| RequestError, datahub_raise_on_error, | |||
| handle_http_post_error, | |||
| handle_http_response, is_ok, raise_on_error) | |||
| from modelscope.hub.git import GitCommandWrapper | |||
| @@ -306,6 +312,169 @@ class HubApi: | |||
| r.raise_for_status() | |||
| return None | |||
| def deploy_model(self, model_id: str, revision: str, instance_name: str, | |||
| resource: ServiceResourceConfig, | |||
| provider: ServiceParameters): | |||
| """Deploy model to cloud, current we only support PAI EAS, this is asynchronous | |||
| call , please check instance status through the console or query the instance status. | |||
| At the same time, this call may take a long time. | |||
| Args: | |||
| model_id (str): The deployed model id | |||
| revision (str): The model revision | |||
| instance_name (str): The deployed model instance name. | |||
| resource (DeployResource): The resource information. | |||
| provider (CreateParameter): The cloud service provider parameter | |||
| Raises: | |||
| NotLoginException: To use this api, you need login first. | |||
| NotSupportError: Not supported platform. | |||
| RequestError: The server return error. | |||
| Returns: | |||
| InstanceInfo: The instance information. | |||
| """ | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| if cookies is None: | |||
| raise NotLoginException( | |||
| 'Token does not exist, please login first.') | |||
| if provider.vendor != Vendor.EAS: | |||
| raise NotSupportError( | |||
| 'Not support vendor: %s ,only support EAS current.' % | |||
| (provider.vendor)) | |||
| create_params = DeployServiceParameters( | |||
| instance_name=instance_name, | |||
| model_id=model_id, | |||
| revision=revision, | |||
| resource=resource, | |||
| provider=provider) | |||
| path = f'{self.endpoint}/api/v1/deployer/endpoint' | |||
| body = attrs.asdict(create_params) | |||
| r = requests.post( | |||
| path, | |||
| json=body, | |||
| cookies=cookies, | |||
| ) | |||
| handle_http_response(r, logger, cookies, 'create_eas_instance') | |||
| if r.status_code >= HTTPStatus.OK and r.status_code < HTTPStatus.MULTIPLE_CHOICES: | |||
| if is_ok(r.json()): | |||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||
| return data | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| r.raise_for_status() | |||
| return None | |||
| def list_deployed_model_instances(self, | |||
| provider: ServiceParameters, | |||
| skip: int = 0, | |||
| limit: int = 100): | |||
| """List deployed model instances. | |||
| Args: | |||
| provider (ListServiceParameter): The cloud service provider parameter, | |||
| for eas, need access_key_id and access_key_secret. | |||
| skip: start of the list, current not support. | |||
| limit: maximum number of instances return, current not support | |||
| Raises: | |||
| NotLoginException: To use this api, you need login first. | |||
| RequestError: The request is failed from server. | |||
| Returns: | |||
| List: List of instance information | |||
| """ | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| if cookies is None: | |||
| raise NotLoginException( | |||
| 'Token does not exist, please login first.') | |||
| params = ListServiceParameters( | |||
| provider=provider, skip=skip, limit=limit) | |||
| path = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint, | |||
| params.to_query_str()) | |||
| r = requests.get(path, cookies=cookies) | |||
| handle_http_response(r, logger, cookies, 'list_deployed_model') | |||
| if r.status_code == HTTPStatus.OK: | |||
| if is_ok(r.json()): | |||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||
| return data | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| r.raise_for_status() | |||
| return None | |||
| def get_deployed_model_instance(self, instance_name: str, | |||
| provider: ServiceParameters): | |||
| """Query the specified instance information. | |||
| Args: | |||
| instance_name (str): The deployed instance name. | |||
| provider (GetParameter): The cloud provider information, for eas | |||
| need region(eg: ch-hangzhou), access_key_id and access_key_secret. | |||
| Raises: | |||
| NotLoginException: To use this api, you need login first. | |||
| RequestError: The request is failed from server. | |||
| Returns: | |||
| Dict: The request instance information | |||
| """ | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| if cookies is None: | |||
| raise NotLoginException( | |||
| 'Token does not exist, please login first.') | |||
| params = GetServiceParameters(provider=provider) | |||
| path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | |||
| self.endpoint, instance_name, params.to_query_str()) | |||
| r = requests.get(path, cookies=cookies) | |||
| handle_http_response(r, logger, cookies, 'get_deployed_model') | |||
| if r.status_code == HTTPStatus.OK: | |||
| if is_ok(r.json()): | |||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||
| return data | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| r.raise_for_status() | |||
| return None | |||
| def delete_deployed_model_instance(self, instance_name: str, | |||
| provider: ServiceParameters): | |||
| """Delete deployed model, this api send delete command and return, it will take | |||
| some to delete, please check through the cloud console. | |||
| Args: | |||
| instance_name (str): The instance name you want to delete. | |||
| provider (DeleteParameter): The cloud provider information, for eas | |||
| need region(eg: ch-hangzhou), access_key_id and access_key_secret. | |||
| Raises: | |||
| NotLoginException: To call this api, you need login first. | |||
| RequestError: The request is failed. | |||
| Returns: | |||
| Dict: The deleted instance information. | |||
| """ | |||
| cookies = ModelScopeConfig.get_cookies() | |||
| if cookies is None: | |||
| raise NotLoginException( | |||
| 'Token does not exist, please login first.') | |||
| params = DeleteServiceParameters(provider=provider) | |||
| path = '%s/api/v1/deployer/endpoint/%s?%s' % ( | |||
| self.endpoint, instance_name, params.to_query_str()) | |||
| r = requests.delete(path, cookies=cookies) | |||
| handle_http_response(r, logger, cookies, 'delete_deployed_model') | |||
| if r.status_code == HTTPStatus.OK: | |||
| if is_ok(r.json()): | |||
| data = r.json()[API_RESPONSE_FIELD_DATA] | |||
| return data | |||
| else: | |||
| raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE]) | |||
| else: | |||
| r.raise_for_status() | |||
| return None | |||
| def _check_cookie(self, | |||
| use_cookies: Union[bool, | |||
| CookieJar] = False) -> CookieJar: | |||
| @@ -0,0 +1,189 @@ | |||
| import urllib | |||
| from abc import ABC, abstractmethod | |||
| from typing import Optional, Union | |||
| import json | |||
| from attr import fields | |||
| from attrs import asdict, define, field, validators | |||
| class Accelerator(object): | |||
| CPU = 'cpu' | |||
| GPU = 'gpu' | |||
| class Vendor(object): | |||
| EAS = 'eas' | |||
| class EASRegion(object): | |||
| beijing = 'cn-beijing' | |||
| hangzhou = 'cn-hangzhou' | |||
| class EASCpuInstanceType(object): | |||
| """EAS Cpu Instance TYpe, ref(https://help.aliyun.com/document_detail/144261.html) | |||
| """ | |||
| tiny = 'ecs.c6.2xlarge' | |||
| small = 'ecs.c6.4xlarge' | |||
| medium = 'ecs.c6.6xlarge' | |||
| large = 'ecs.c6.8xlarge' | |||
| class EASGpuInstanceType(object): | |||
| """EAS Cpu Instance TYpe, ref(https://help.aliyun.com/document_detail/144261.html) | |||
| """ | |||
| tiny = 'ecs.gn5-c28g1.7xlarge' | |||
| small = 'ecs.gn5-c8g1.4xlarge' | |||
| medium = 'ecs.gn6i-c24g1.12xlarge' | |||
| large = 'ecs.gn6e-c12g1.3xlarge' | |||
| def min_smaller_than_max(instance, attribute, value): | |||
| if value > instance.max_replica: | |||
| raise ValueError( | |||
| "'min_replica' value: %s has to be smaller than 'max_replica' value: %s!" | |||
| % (value, instance.max_replica)) | |||
| @define | |||
| class ServiceScalingConfig(object): | |||
| """Resource scaling config | |||
| Currently we ignore max_replica | |||
| Args: | |||
| max_replica: maximum replica | |||
| min_replica: minimum replica | |||
| """ | |||
| max_replica: int = field(default=1, validator=validators.ge(1)) | |||
| min_replica: int = field( | |||
| default=1, validator=[validators.ge(1), min_smaller_than_max]) | |||
| @define | |||
| class ServiceResourceConfig(object): | |||
| """Eas Resource request. | |||
| Args: | |||
| accelerator: the accelerator(cpu|gpu) | |||
| instance_type: the instance type. | |||
| scaling: The instance scaling config. | |||
| """ | |||
| instance_type: str | |||
| scaling: ServiceScalingConfig | |||
| accelerator: str = field( | |||
| default=Accelerator.CPU, | |||
| validator=validators.in_([Accelerator.CPU, Accelerator.GPU])) | |||
| @define | |||
| class ServiceParameters(ABC): | |||
| pass | |||
| @define | |||
| class EASDeployParameters(ServiceParameters): | |||
| """Parameters for EAS Deployment. | |||
| Args: | |||
| resource_group: the resource group to deploy, current default. | |||
| region: The eas instance region(eg: cn-hangzhou). | |||
| access_key_id: The eas account access key id. | |||
| access_key_secret: The eas account access key secret. | |||
| vendor: must be 'eas' | |||
| """ | |||
| region: str | |||
| access_key_id: str | |||
| access_key_secret: str | |||
| resource_group: Optional[str] = None | |||
| vendor: str = field( | |||
| default=Vendor.EAS, validator=validators.in_([Vendor.EAS])) | |||
| """ | |||
| def __init__(self, | |||
| instance_name: str, | |||
| access_key_id: str, | |||
| access_key_secret: str, | |||
| region = EASRegion.beijing, | |||
| instance_type: str = EASCpuInstances.small, | |||
| accelerator: str = Accelerator.CPU, | |||
| resource_group: Optional[str] = None, | |||
| scaling: Optional[str] = None): | |||
| self.instance_name=instance_name | |||
| self.access_key_id=self.access_key_id | |||
| self.access_key_secret = access_key_secret | |||
| self.region = region | |||
| self.instance_type = instance_type | |||
| self.accelerator = accelerator | |||
| self.resource_group = resource_group | |||
| self.scaling = scaling | |||
| """ | |||
| @define | |||
| class EASListParameters(ServiceParameters): | |||
| """EAS instance list parameters. | |||
| Args: | |||
| resource_group: the resource group to deploy, current default. | |||
| region: The eas instance region(eg: cn-hangzhou). | |||
| access_key_id: The eas account access key id. | |||
| access_key_secret: The eas account access key secret. | |||
| vendor: must be 'eas' | |||
| """ | |||
| access_key_id: str | |||
| access_key_secret: str | |||
| region: str = None | |||
| resource_group: str = None | |||
| vendor: str = field( | |||
| default=Vendor.EAS, validator=validators.in_([Vendor.EAS])) | |||
| @define | |||
| class DeployServiceParameters(object): | |||
| """Deploy service parameters | |||
| Args: | |||
| instance_name: the name of the service. | |||
| model_id: the modelscope model_id | |||
| revision: the modelscope model revision | |||
| resource: the resource requirement. | |||
| provider: the cloud service provider. | |||
| """ | |||
| instance_name: str | |||
| model_id: str | |||
| revision: str | |||
| resource: ServiceResourceConfig | |||
| provider: ServiceParameters | |||
| class AttrsToQueryString(ABC): | |||
| """Convert the attrs class to json string. | |||
| Args: | |||
| """ | |||
| def to_query_str(self): | |||
| self_dict = asdict( | |||
| self.provider, filter=lambda attr, value: value is not None) | |||
| json_str = json.dumps(self_dict) | |||
| print(json_str) | |||
| safe_str = urllib.parse.quote_plus(json_str) | |||
| print(safe_str) | |||
| query_param = 'provider=%s' % safe_str | |||
| return query_param | |||
| @define | |||
| class ListServiceParameters(AttrsToQueryString): | |||
| provider: ServiceParameters | |||
| skip: int = 0 | |||
| limit: int = 100 | |||
| @define | |||
| class GetServiceParameters(AttrsToQueryString): | |||
| provider: ServiceParameters | |||
| @define | |||
| class DeleteServiceParameters(AttrsToQueryString): | |||
| provider: ServiceParameters | |||
| @@ -9,6 +9,10 @@ from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| class NotSupportError(Exception): | |||
| pass | |||
| class NotExistError(Exception): | |||
| pass | |||
| @@ -66,6 +70,7 @@ def handle_http_response(response, logger, cookies, model_id): | |||
| logger.error( | |||
| f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be \ | |||
| private. Please login first.') | |||
| logger.error('Response details: %s' % response.content) | |||
| raise error | |||
| @@ -67,8 +67,9 @@ class Models(object): | |||
| space_dst = 'space-dst' | |||
| space_intent = 'space-intent' | |||
| space_modeling = 'space-modeling' | |||
| star = 'star' | |||
| star3 = 'star3' | |||
| space_T_en = 'space-T-en' | |||
| space_T_cn = 'space-T-cn' | |||
| tcrf = 'transformer-crf' | |||
| transformer_softmax = 'transformer-softmax' | |||
| lcrf = 'lstm-crf' | |||
| @@ -16,6 +16,7 @@ from modelscope.models.builder import MODELS | |||
| from modelscope.preprocessors import LoadImage | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from .utils import timestamp_format | |||
| from .yolox.data.data_augment import ValTransform | |||
| from .yolox.exp import get_exp_by_name | |||
| from .yolox.utils import postprocess | |||
| @@ -99,14 +100,17 @@ class RealtimeVideoDetector(TorchModel): | |||
| def inference_video(self, v_path): | |||
| outputs = [] | |||
| desc = 'Detecting video: {}'.format(v_path) | |||
| for frame, result in tqdm( | |||
| self.inference_video_iter(v_path), desc=desc): | |||
| for frame_idx, (frame, result) in enumerate( | |||
| tqdm(self.inference_video_iter(v_path), desc=desc)): | |||
| result = result + (timestamp_format(seconds=frame_idx | |||
| / self.fps), ) | |||
| outputs.append(result) | |||
| return outputs | |||
| def inference_video_iter(self, v_path): | |||
| capture = cv2.VideoCapture(v_path) | |||
| self.fps = capture.get(cv2.CAP_PROP_FPS) | |||
| while capture.isOpened(): | |||
| ret, frame = capture.read() | |||
| if not ret: | |||
| @@ -0,0 +1,9 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import math | |||
| def timestamp_format(seconds): | |||
| m, s = divmod(seconds, 60) | |||
| h, m = divmod(m, 60) | |||
| time = '%02d:%02d:%06.3f' % (h, m, s) | |||
| return time | |||
| @@ -24,8 +24,8 @@ import json | |||
| logger = logging.getLogger(__name__) | |||
| class Star3Config(object): | |||
| """Configuration class to store the configuration of a `Star3Model`. | |||
| class SpaceTCnConfig(object): | |||
| """Configuration class to store the configuration of a `SpaceTCnModel`. | |||
| """ | |||
| def __init__(self, | |||
| @@ -40,10 +40,10 @@ class Star3Config(object): | |||
| max_position_embeddings=512, | |||
| type_vocab_size=2, | |||
| initializer_range=0.02): | |||
| """Constructs Star3Config. | |||
| """Constructs SpaceTCnConfig. | |||
| Args: | |||
| vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `Star3Model`. | |||
| vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `SpaceTCnConfig`. | |||
| hidden_size: Size of the encoder layers and the pooler layer. | |||
| num_hidden_layers: Number of hidden layers in the Transformer encoder. | |||
| num_attention_heads: Number of attention heads for each attention layer in | |||
| @@ -59,7 +59,7 @@ class Star3Config(object): | |||
| max_position_embeddings: The maximum sequence length that this model might | |||
| ever be used with. Typically set this to something large just in case | |||
| (e.g., 512 or 1024 or 2048). | |||
| type_vocab_size: The vocabulary size of the `token_type_ids` passed into `Star3Model`. | |||
| type_vocab_size: The vocabulary size of the `token_type_ids` passed into `SpaceTCnConfig`. | |||
| initializer_range: The sttdev of the truncated_normal_initializer for | |||
| initializing all weight matrices. | |||
| """ | |||
| @@ -89,15 +89,15 @@ class Star3Config(object): | |||
| @classmethod | |||
| def from_dict(cls, json_object): | |||
| """Constructs a `Star3Config` from a Python dictionary of parameters.""" | |||
| config = Star3Config(vocab_size_or_config_json_file=-1) | |||
| """Constructs a `SpaceTCnConfig` from a Python dictionary of parameters.""" | |||
| config = SpaceTCnConfig(vocab_size_or_config_json_file=-1) | |||
| for key, value in json_object.items(): | |||
| config.__dict__[key] = value | |||
| return config | |||
| @classmethod | |||
| def from_json_file(cls, json_file): | |||
| """Constructs a `Star3Config` from a json file of parameters.""" | |||
| """Constructs a `SpaceTCnConfig` from a json file of parameters.""" | |||
| with open(json_file, 'r', encoding='utf-8') as reader: | |||
| text = reader.read() | |||
| return cls.from_dict(json.loads(text)) | |||
| @@ -27,7 +27,8 @@ import numpy as np | |||
| import torch | |||
| from torch import nn | |||
| from modelscope.models.nlp.star3.configuration_star3 import Star3Config | |||
| from modelscope.models.nlp.space_T_cn.configuration_space_T_cn import \ | |||
| SpaceTCnConfig | |||
| from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.logger import get_logger | |||
| @@ -609,9 +610,9 @@ class PreTrainedBertModel(nn.Module): | |||
| def __init__(self, config, *inputs, **kwargs): | |||
| super(PreTrainedBertModel, self).__init__() | |||
| if not isinstance(config, Star3Config): | |||
| if not isinstance(config, SpaceTCnConfig): | |||
| raise ValueError( | |||
| 'Parameter config in `{}(config)` should be an instance of class `Star3Config`. ' | |||
| 'Parameter config in `{}(config)` should be an instance of class `SpaceTCnConfig`. ' | |||
| 'To create a model from a Google pretrained model use ' | |||
| '`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`'.format( | |||
| self.__class__.__name__, self.__class__.__name__)) | |||
| @@ -676,7 +677,7 @@ class PreTrainedBertModel(nn.Module): | |||
| serialization_dir = tempdir | |||
| # Load config | |||
| config_file = os.path.join(serialization_dir, CONFIG_NAME) | |||
| config = Star3Config.from_json_file(config_file) | |||
| config = SpaceTCnConfig.from_json_file(config_file) | |||
| logger.info('Model config {}'.format(config)) | |||
| # Instantiate model. | |||
| model = cls(config, *inputs, **kwargs) | |||
| @@ -742,11 +743,11 @@ class PreTrainedBertModel(nn.Module): | |||
| return model | |||
| class Star3Model(PreTrainedBertModel): | |||
| """Star3Model model ("Bidirectional Embedding Representations from a Transformer pretrained on STAR3.0"). | |||
| class SpaceTCnModel(PreTrainedBertModel): | |||
| """SpaceTCnModel model ("Bidirectional Embedding Representations from a Transformer pretrained on STAR-T-CN"). | |||
| Params: | |||
| config: a Star3Config class instance with the configuration to build a new model | |||
| config: a SpaceTCnConfig class instance with the configuration to build a new model | |||
| Inputs: | |||
| `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] | |||
| @@ -780,16 +781,16 @@ class Star3Model(PreTrainedBertModel): | |||
| input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | |||
| token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) | |||
| config = modeling.Star3Config(vocab_size_or_config_json_file=32000, hidden_size=768, | |||
| config = modeling.SpaceTCnConfig(vocab_size_or_config_json_file=32000, hidden_size=768, | |||
| num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | |||
| model = modeling.Star3Model(config=config) | |||
| model = modeling.SpaceTCnModel(config=config) | |||
| all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) | |||
| ``` | |||
| """ | |||
| def __init__(self, config, schema_link_module='none'): | |||
| super(Star3Model, self).__init__(config) | |||
| super(SpaceTCnModel, self).__init__(config) | |||
| self.embeddings = BertEmbeddings(config) | |||
| self.encoder = BertEncoder( | |||
| config, schema_link_module=schema_link_module) | |||
| @@ -20,7 +20,7 @@ __all__ = ['StarForTextToSql'] | |||
| @MODELS.register_module( | |||
| Tasks.conversational_text_to_sql, module_name=Models.star) | |||
| Tasks.table_question_answering, module_name=Models.space_T_en) | |||
| class StarForTextToSql(Model): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| @@ -3,27 +3,25 @@ | |||
| import os | |||
| from typing import Dict | |||
| import json | |||
| import numpy | |||
| import torch | |||
| import torch.nn.functional as F | |||
| import tqdm | |||
| from transformers import BertTokenizer | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import Model, Tensor | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.models.nlp.star3.configuration_star3 import Star3Config | |||
| from modelscope.models.nlp.star3.modeling_star3 import Seq2SQL, Star3Model | |||
| from modelscope.preprocessors.star3.fields.struct import Constant | |||
| from modelscope.preprocessors.space_T_cn.fields.struct import Constant | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.device import verify_device | |||
| from .space_T_cn.configuration_space_T_cn import SpaceTCnConfig | |||
| from .space_T_cn.modeling_space_T_cn import Seq2SQL, SpaceTCnModel | |||
| __all__ = ['TableQuestionAnswering'] | |||
| @MODELS.register_module( | |||
| Tasks.table_question_answering, module_name=Models.star3) | |||
| Tasks.table_question_answering, module_name=Models.space_T_cn) | |||
| class TableQuestionAnswering(Model): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| @@ -43,9 +41,9 @@ class TableQuestionAnswering(Model): | |||
| os.path.join(self.model_dir, ModelFile.TORCH_MODEL_BIN_FILE), | |||
| map_location='cpu') | |||
| self.backbone_config = Star3Config.from_json_file( | |||
| self.backbone_config = SpaceTCnConfig.from_json_file( | |||
| os.path.join(self.model_dir, ModelFile.CONFIGURATION)) | |||
| self.backbone_model = Star3Model( | |||
| self.backbone_model = SpaceTCnModel( | |||
| config=self.backbone_config, schema_link_module='rat') | |||
| self.backbone_model.load_state_dict(state_dict['backbone_model']) | |||
| @@ -606,21 +606,12 @@ TASK_OUTPUTS = { | |||
| # } | |||
| Tasks.task_oriented_conversation: [OutputKeys.OUTPUT], | |||
| # conversational text-to-sql result for single sample | |||
| # { | |||
| # "text": "SELECT shop.Name FROM shop." | |||
| # } | |||
| Tasks.conversational_text_to_sql: [OutputKeys.TEXT], | |||
| # table-question-answering result for single sample | |||
| # { | |||
| # "sql": "SELECT shop.Name FROM shop." | |||
| # "sql_history": {sel: 0, agg: 0, conds: [[0, 0, 'val']]} | |||
| # } | |||
| Tasks.table_question_answering: [ | |||
| OutputKeys.SQL_STRING, OutputKeys.SQL_QUERY, OutputKeys.HISTORY, | |||
| OutputKeys.QUERT_RESULT | |||
| ], | |||
| Tasks.table_question_answering: [OutputKeys.OUTPUT], | |||
| # ============ audio tasks =================== | |||
| # asr result for single sample | |||
| @@ -69,9 +69,6 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| 'damo/nlp_space_dialog-modeling'), | |||
| Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking, | |||
| 'damo/nlp_space_dialog-state-tracking'), | |||
| Tasks.conversational_text_to_sql: | |||
| (Pipelines.conversational_text_to_sql, | |||
| 'damo/nlp_star_conversational-text-to-sql'), | |||
| Tasks.table_question_answering: | |||
| (Pipelines.table_question_answering_pipeline, | |||
| 'damo/nlp-convai-text2sql-pretrain-cn'), | |||
| @@ -113,9 +113,8 @@ class AnimalRecognitionPipeline(Pipeline): | |||
| label_mapping = f.readlines() | |||
| score = torch.max(inputs['outputs']) | |||
| inputs = { | |||
| OutputKeys.SCORES: | |||
| score.item(), | |||
| OutputKeys.SCORES: [score.item()], | |||
| OutputKeys.LABELS: | |||
| label_mapping[inputs['outputs'].argmax()].split('\t')[1] | |||
| [label_mapping[inputs['outputs'].argmax()].split('\t')[1]] | |||
| } | |||
| return inputs | |||
| @@ -114,9 +114,8 @@ class GeneralRecognitionPipeline(Pipeline): | |||
| label_mapping = f.readlines() | |||
| score = torch.max(inputs['outputs']) | |||
| inputs = { | |||
| OutputKeys.SCORES: | |||
| score.item(), | |||
| OutputKeys.SCORES: [score.item()], | |||
| OutputKeys.LABELS: | |||
| label_mapping[inputs['outputs'].argmax()].split('\t')[1] | |||
| [label_mapping[inputs['outputs'].argmax()].split('\t')[1]] | |||
| } | |||
| return inputs | |||
| @@ -45,15 +45,17 @@ class RealtimeVideoObjectDetectionPipeline(Pipeline): | |||
| **kwargs) -> str: | |||
| forward_output = input['forward_output'] | |||
| scores, boxes, labels = [], [], [] | |||
| scores, boxes, labels, timestamps = [], [], [], [] | |||
| for result in forward_output: | |||
| box, score, label = result | |||
| box, score, label, timestamp = result | |||
| scores.append(score) | |||
| boxes.append(box) | |||
| labels.append(label) | |||
| timestamps.append(timestamp) | |||
| return { | |||
| OutputKeys.BOXES: boxes, | |||
| OutputKeys.SCORES: scores, | |||
| OutputKeys.LABELS: labels, | |||
| OutputKeys.TIMESTAMPS: timestamps, | |||
| } | |||
| @@ -19,7 +19,7 @@ __all__ = ['ConversationalTextToSqlPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.conversational_text_to_sql, | |||
| Tasks.table_question_answering, | |||
| module_name=Pipelines.conversational_text_to_sql) | |||
| class ConversationalTextToSqlPipeline(Pipeline): | |||
| @@ -62,7 +62,7 @@ class ConversationalTextToSqlPipeline(Pipeline): | |||
| Dict[str, str]: the prediction results | |||
| """ | |||
| sql = Example.evaluator.obtain_sql(inputs['predict'][0], inputs['db']) | |||
| result = {OutputKeys.TEXT: sql} | |||
| result = {OutputKeys.OUTPUT: {OutputKeys.TEXT: sql}} | |||
| return result | |||
| def _collate_fn(self, data): | |||
| @@ -13,8 +13,9 @@ from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import TableQuestionAnsweringPreprocessor | |||
| from modelscope.preprocessors.star3.fields.database import Database | |||
| from modelscope.preprocessors.star3.fields.struct import Constant, SQLQuery | |||
| from modelscope.preprocessors.space_T_cn.fields.database import Database | |||
| from modelscope.preprocessors.space_T_cn.fields.struct import (Constant, | |||
| SQLQuery) | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| __all__ = ['TableQuestionAnsweringPipeline'] | |||
| @@ -320,7 +321,7 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||
| OutputKeys.QUERT_RESULT: tabledata, | |||
| } | |||
| return output | |||
| return {OutputKeys.OUTPUT: output} | |||
| def _collate_fn(self, data): | |||
| return data | |||
| @@ -40,7 +40,7 @@ if TYPE_CHECKING: | |||
| DialogStateTrackingPreprocessor) | |||
| from .video import ReadVideoData, MovieSceneSegmentationPreprocessor | |||
| from .star import ConversationalTextToSqlPreprocessor | |||
| from .star3 import TableQuestionAnsweringPreprocessor | |||
| from .space_T_cn import TableQuestionAnsweringPreprocessor | |||
| else: | |||
| _import_structure = { | |||
| @@ -81,7 +81,7 @@ else: | |||
| 'DialogStateTrackingPreprocessor', 'InputFeatures' | |||
| ], | |||
| 'star': ['ConversationalTextToSqlPreprocessor'], | |||
| 'star3': ['TableQuestionAnsweringPreprocessor'], | |||
| 'space_T_cn': ['TableQuestionAnsweringPreprocessor'], | |||
| } | |||
| import sys | |||
| @@ -4,7 +4,7 @@ import sqlite3 | |||
| import json | |||
| import tqdm | |||
| from modelscope.preprocessors.star3.fields.struct import Trie | |||
| from modelscope.preprocessors.space_T_cn.fields.struct import Trie | |||
| class Database: | |||
| @@ -1,7 +1,7 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import re | |||
| from modelscope.preprocessors.star3.fields.struct import TypeInfo | |||
| from modelscope.preprocessors.space_T_cn.fields.struct import TypeInfo | |||
| class SchemaLinker: | |||
| @@ -8,8 +8,8 @@ from transformers import BertTokenizer | |||
| from modelscope.metainfo import Preprocessors | |||
| from modelscope.preprocessors.base import Preprocessor | |||
| from modelscope.preprocessors.builder import PREPROCESSORS | |||
| from modelscope.preprocessors.star3.fields.database import Database | |||
| from modelscope.preprocessors.star3.fields.schema_link import SchemaLinker | |||
| from modelscope.preprocessors.space_T_cn.fields.database import Database | |||
| from modelscope.preprocessors.space_T_cn.fields.schema_link import SchemaLinker | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import Fields, ModelFile | |||
| from modelscope.utils.type_assert import type_assert | |||
| @@ -123,7 +123,6 @@ class NLPTasks(object): | |||
| backbone = 'backbone' | |||
| text_error_correction = 'text-error-correction' | |||
| faq_question_answering = 'faq-question-answering' | |||
| conversational_text_to_sql = 'conversational-text-to-sql' | |||
| information_extraction = 'information-extraction' | |||
| document_segmentation = 'document-segmentation' | |||
| feature_extraction = 'feature-extraction' | |||
| @@ -20,7 +20,7 @@ def text2sql_tracking_and_print_results( | |||
| results = p(case) | |||
| print({'question': item}) | |||
| print(results) | |||
| last_sql = results['text'] | |||
| last_sql = results[OutputKeys.OUTPUT][OutputKeys.TEXT] | |||
| history.append(item) | |||
| @@ -1,4 +1,5 @@ | |||
| addict | |||
| attrs | |||
| datasets | |||
| easydict | |||
| einops | |||
| @@ -16,7 +16,7 @@ from modelscope.utils.test_utils import test_level | |||
| class ConversationalTextToSql(unittest.TestCase, DemoCompatibilityCheck): | |||
| def setUp(self) -> None: | |||
| self.task = Tasks.conversational_text_to_sql | |||
| self.task = Tasks.table_question_answering | |||
| self.model_id = 'damo/nlp_star_conversational-text-to-sql' | |||
| model_id = 'damo/nlp_star_conversational-text-to-sql' | |||
| @@ -66,11 +66,6 @@ class ConversationalTextToSql(unittest.TestCase, DemoCompatibilityCheck): | |||
| pipelines = [pipeline(task=self.task, model=self.model_id)] | |||
| text2sql_tracking_and_print_results(self.test_case, pipelines) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipelines = [pipeline(task=self.task)] | |||
| text2sql_tracking_and_print_results(self.test_case, pipelines) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_demo_compatibility(self): | |||
| self.compatibility_check() | |||
| @@ -12,7 +12,7 @@ from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline | |||
| from modelscope.preprocessors import TableQuestionAnsweringPreprocessor | |||
| from modelscope.preprocessors.star3.fields.database import Database | |||
| from modelscope.preprocessors.space_T_cn.fields.database import Database | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| @@ -38,7 +38,7 @@ def tableqa_tracking_and_print_results_with_history( | |||
| output_dict = p({ | |||
| 'question': question, | |||
| 'history_sql': historical_queries | |||
| }) | |||
| })[OutputKeys.OUTPUT] | |||
| print('question', question) | |||
| print('sql text:', output_dict[OutputKeys.SQL_STRING]) | |||
| print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | |||
| @@ -61,7 +61,7 @@ def tableqa_tracking_and_print_results_without_history( | |||
| } | |||
| for p in pipelines: | |||
| for question in test_case['utterance']: | |||
| output_dict = p({'question': question}) | |||
| output_dict = p({'question': question})[OutputKeys.OUTPUT] | |||
| print('question', question) | |||
| print('sql text:', output_dict[OutputKeys.SQL_STRING]) | |||
| print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | |||
| @@ -92,7 +92,7 @@ def tableqa_tracking_and_print_results_with_tableid( | |||
| 'question': question, | |||
| 'table_id': table_id, | |||
| 'history_sql': historical_queries | |||
| }) | |||
| })[OutputKeys.OUTPUT] | |||
| print('question', question) | |||
| print('sql text:', output_dict[OutputKeys.SQL_STRING]) | |||
| print('sql query:', output_dict[OutputKeys.SQL_QUERY]) | |||
| @@ -147,11 +147,6 @@ class TableQuestionAnswering(unittest.TestCase): | |||
| ] | |||
| tableqa_tracking_and_print_results_with_tableid(pipelines) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_task(self): | |||
| pipelines = [pipeline(Tasks.table_question_answering, self.model_id)] | |||
| tableqa_tracking_and_print_results_with_history(pipelines) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub_with_other_classes(self): | |||
| model = Model.from_pretrained(self.model_id) | |||