# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Dict

import torch
from text2sql_lgesql.asdl.asdl import ASDLGrammar
from text2sql_lgesql.asdl.transition_system import TransitionSystem
from text2sql_lgesql.model.model_constructor import Text2SQL

from modelscope.metainfo import Models
from modelscope.models.base import Model, Tensor
from modelscope.models.builder import MODELS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks

__all__ = ['StarForTextToSql']


@MODELS.register_module(
    Tasks.table_question_answering, module_name=Models.space_T_en)
class StarForTextToSql(Model):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the STAR text-to-SQL model from the `model_dir` path.

        Args:
            model_dir (str): the local directory holding the model
                configuration, the SQL grammar file and the checkpoint.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.beam_size = 5
        # Use the config passed via kwargs if given, otherwise read the
        # configuration file shipped with the model.
        self.config = kwargs.pop(
            'config',
            Config.from_file(
                os.path.join(self.model_dir, ModelFile.CONFIGURATION)))
        self.config.model.model_dir = model_dir
        # Build the SQL ASDL grammar and its transition system, which map
        # between SQL queries and decoder action sequences.
        self.grammar = ASDLGrammar.from_filepath(
            os.path.join(model_dir, 'sql_asdl_v2.txt'))
        self.trans = TransitionSystem.get_class_by_lang('sql')(self.grammar)
        self.arg = self.config.model
        # Run on GPU unless the caller explicitly asked for CPU or no GPU is
        # available.
        self.device = 'cuda' if \
            ('device' not in kwargs or kwargs['device'] == 'gpu') \
            and torch.cuda.is_available() else 'cpu'
        # Construct the LGESQL-based parser and restore the trained weights.
        self.model = Text2SQL(self.arg, self.trans)
        check_point = torch.load(
            open(
                os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'rb'),
            map_location=self.device)
        self.model.load_state_dict(check_point['model'])

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Parse the preprocessed input into SQL hypotheses.

        Args:
            input (Dict[str, Tensor]): the preprocessed data; its 'batch'
                entry holds the examples to be parsed.

        Returns:
            Dict[str, Tensor]: a dict with the beam-search hypotheses under
                'predict' and the database of the first example under 'db'.
        """
        self.model.eval()
        # Beam-search decoding; each hypothesis is a candidate SQL parse.
        hyps = self.model.parse(input['batch'], self.beam_size)
        db = input['batch'].examples[0].db

        predict = {'predict': hyps, 'db': db}
        return predict
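

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the lines below assume a local model
# directory './star_text2sql' containing configuration.json, sql_asdl_v2.txt
# and pytorch_model.bin, plus an `inputs` dict whose 'batch' entry has already
# been built by the matching text-to-SQL preprocessor; the directory path and
# the `inputs` variable are placeholders, not values defined in this module.
#
#     model = StarForTextToSql('./star_text2sql')
#     outputs = model.forward(inputs)
#     outputs['predict']  # beam-search hypotheses (candidate SQL parses)
#     outputs['db']       # database entry of the first example in the batch
# ---------------------------------------------------------------------------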