commit nlp_convai_text2sql_pretrain_cn inference process to modelscope
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10025155
master
| @@ -55,6 +55,7 @@ class Models(object): | |||||
| space_intent = 'space-intent' | space_intent = 'space-intent' | ||||
| space_modeling = 'space-modeling' | space_modeling = 'space-modeling' | ||||
| star = 'star' | star = 'star' | ||||
| star3 = 'star3' | |||||
| tcrf = 'transformer-crf' | tcrf = 'transformer-crf' | ||||
| transformer_softmax = 'transformer-softmax' | transformer_softmax = 'transformer-softmax' | ||||
| lcrf = 'lstm-crf' | lcrf = 'lstm-crf' | ||||
| @@ -193,6 +194,7 @@ class Pipelines(object): | |||||
| plug_generation = 'plug-generation' | plug_generation = 'plug-generation' | ||||
| faq_question_answering = 'faq-question-answering' | faq_question_answering = 'faq-question-answering' | ||||
| conversational_text_to_sql = 'conversational-text-to-sql' | conversational_text_to_sql = 'conversational-text-to-sql' | ||||
| table_question_answering_pipeline = 'table-question-answering-pipeline' | |||||
| sentence_embedding = 'sentence-embedding' | sentence_embedding = 'sentence-embedding' | ||||
| passage_ranking = 'passage-ranking' | passage_ranking = 'passage-ranking' | ||||
| relation_extraction = 'relation-extraction' | relation_extraction = 'relation-extraction' | ||||
| @@ -296,6 +298,7 @@ class Preprocessors(object): | |||||
| fill_mask_ponet = 'fill-mask-ponet' | fill_mask_ponet = 'fill-mask-ponet' | ||||
| faq_question_answering_preprocessor = 'faq-question-answering-preprocessor' | faq_question_answering_preprocessor = 'faq-question-answering-preprocessor' | ||||
| conversational_text_to_sql = 'conversational-text-to-sql' | conversational_text_to_sql = 'conversational-text-to-sql' | ||||
| table_question_answering_preprocessor = 'table-question-answering-preprocessor' | |||||
| re_tokenizer = 're-tokenizer' | re_tokenizer = 're-tokenizer' | ||||
| document_segmentation = 'document-segmentation' | document_segmentation = 'document-segmentation' | ||||
| @@ -24,6 +24,7 @@ if TYPE_CHECKING: | |||||
| from .space import SpaceForDialogIntent | from .space import SpaceForDialogIntent | ||||
| from .space import SpaceForDialogModeling | from .space import SpaceForDialogModeling | ||||
| from .space import SpaceForDialogStateTracking | from .space import SpaceForDialogStateTracking | ||||
| from .table_question_answering import TableQuestionAnswering | |||||
| from .task_models import (InformationExtractionModel, | from .task_models import (InformationExtractionModel, | ||||
| SequenceClassificationModel, | SequenceClassificationModel, | ||||
| SingleBackboneTaskModelBase, | SingleBackboneTaskModelBase, | ||||
| @@ -64,6 +65,7 @@ else: | |||||
| 'SingleBackboneTaskModelBase', 'TokenClassificationModel' | 'SingleBackboneTaskModelBase', 'TokenClassificationModel' | ||||
| ], | ], | ||||
| 'token_classification': ['SbertForTokenClassification'], | 'token_classification': ['SbertForTokenClassification'], | ||||
| 'table_question_answering': ['TableQuestionAnswering'], | |||||
| 'sentence_embedding': ['SentenceEmbedding'], | 'sentence_embedding': ['SentenceEmbedding'], | ||||
| 'passage_ranking': ['PassageRanking'], | 'passage_ranking': ['PassageRanking'], | ||||
| } | } | ||||
| @@ -0,0 +1,128 @@ | |||||
| # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. | |||||
| # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||||
| # Copyright 2021-2022 The Alibaba DAMO Team Authors. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """PyTorch BERT configuration.""" | |||||
| from __future__ import absolute_import, division, print_function | |||||
| import copy | |||||
| import logging | |||||
| import math | |||||
| import os | |||||
| import shutil | |||||
| import tarfile | |||||
| import tempfile | |||||
| from pathlib import Path | |||||
| from typing import Union | |||||
| import json | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch_scatter | |||||
| from icecream import ic | |||||
| from torch import nn | |||||
| from torch.nn import CrossEntropyLoss | |||||
| logger = logging.getLogger(__name__) | |||||
| class Star3Config(object): | |||||
| """Configuration class to store the configuration of a `Star3Model`. | |||||
| """ | |||||
| def __init__(self, | |||||
| vocab_size_or_config_json_file, | |||||
| hidden_size=768, | |||||
| num_hidden_layers=12, | |||||
| num_attention_heads=12, | |||||
| intermediate_size=3072, | |||||
| hidden_act='gelu', | |||||
| hidden_dropout_prob=0.1, | |||||
| attention_probs_dropout_prob=0.1, | |||||
| max_position_embeddings=512, | |||||
| type_vocab_size=2, | |||||
| initializer_range=0.02): | |||||
| """Constructs Star3Config. | |||||
| Args: | |||||
| vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `Star3Model`. | |||||
| hidden_size: Size of the encoder layers and the pooler layer. | |||||
| num_hidden_layers: Number of hidden layers in the Transformer encoder. | |||||
| num_attention_heads: Number of attention heads for each attention layer in | |||||
| the Transformer encoder. | |||||
| intermediate_size: The size of the "intermediate" (i.e., feed-forward) | |||||
| layer in the Transformer encoder. | |||||
| hidden_act: The non-linear activation function (function or string) in the | |||||
| encoder and pooler. If string, "gelu", "relu" and "swish" are supported. | |||||
| hidden_dropout_prob: The dropout probabilitiy for all fully connected | |||||
| layers in the embeddings, encoder, and pooler. | |||||
| attention_probs_dropout_prob: The dropout ratio for the attention | |||||
| probabilities. | |||||
| max_position_embeddings: The maximum sequence length that this model might | |||||
| ever be used with. Typically set this to something large just in case | |||||
| (e.g., 512 or 1024 or 2048). | |||||
| type_vocab_size: The vocabulary size of the `token_type_ids` passed into `Star3Model`. | |||||
| initializer_range: The sttdev of the truncated_normal_initializer for | |||||
| initializing all weight matrices. | |||||
| """ | |||||
| if isinstance(vocab_size_or_config_json_file, str): | |||||
| with open( | |||||
| vocab_size_or_config_json_file, 'r', | |||||
| encoding='utf-8') as reader: | |||||
| json_config = json.loads(reader.read()) | |||||
| for key, value in json_config.items(): | |||||
| self.__dict__[key] = value | |||||
| elif isinstance(vocab_size_or_config_json_file, int): | |||||
| self.vocab_size = vocab_size_or_config_json_file | |||||
| self.hidden_size = hidden_size | |||||
| self.num_hidden_layers = num_hidden_layers | |||||
| self.num_attention_heads = num_attention_heads | |||||
| self.hidden_act = hidden_act | |||||
| self.intermediate_size = intermediate_size | |||||
| self.hidden_dropout_prob = hidden_dropout_prob | |||||
| self.attention_probs_dropout_prob = attention_probs_dropout_prob | |||||
| self.max_position_embeddings = max_position_embeddings | |||||
| self.type_vocab_size = type_vocab_size | |||||
| self.initializer_range = initializer_range | |||||
| else: | |||||
| raise ValueError( | |||||
| 'First argument must be either a vocabulary size (int)' | |||||
| 'or the path to a pretrained model config file (str)') | |||||
| @classmethod | |||||
| def from_dict(cls, json_object): | |||||
| """Constructs a `Star3Config` from a Python dictionary of parameters.""" | |||||
| config = Star3Config(vocab_size_or_config_json_file=-1) | |||||
| for key, value in json_object.items(): | |||||
| config.__dict__[key] = value | |||||
| return config | |||||
| @classmethod | |||||
| def from_json_file(cls, json_file): | |||||
| """Constructs a `Star3Config` from a json file of parameters.""" | |||||
| with open(json_file, 'r', encoding='utf-8') as reader: | |||||
| text = reader.read() | |||||
| return cls.from_dict(json.loads(text)) | |||||
| def __repr__(self): | |||||
| return str(self.to_json_string()) | |||||
| def to_dict(self): | |||||
| """Serializes this instance to a Python dictionary.""" | |||||
| output = copy.deepcopy(self.__dict__) | |||||
| return output | |||||
| def to_json_string(self): | |||||
| """Serializes this instance to a JSON string.""" | |||||
| return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n' | |||||
| @@ -0,0 +1,747 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| from typing import Dict, Optional | |||||
| import numpy | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| from transformers import BertTokenizer | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base import Model, Tensor | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.models.nlp.star3.configuration_star3 import Star3Config | |||||
| from modelscope.models.nlp.star3.modeling_star3 import Seq2SQL, Star3Model | |||||
| from modelscope.preprocessors.star3.fields.struct import Constant | |||||
| from modelscope.utils.config import Config | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.device import verify_device | |||||
| __all__ = ['TableQuestionAnswering'] | |||||
| @MODELS.register_module( | |||||
| Tasks.table_question_answering, module_name=Models.star3) | |||||
| class TableQuestionAnswering(Model): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """initialize the table-question-answering model from the `model_dir` path. | |||||
| Args: | |||||
| model_dir (str): the model path. | |||||
| """ | |||||
| super().__init__(model_dir, *args, **kwargs) | |||||
| self.tokenizer = BertTokenizer( | |||||
| os.path.join(model_dir, ModelFile.VOCAB_FILE)) | |||||
| device_name = kwargs.get('device', 'gpu') | |||||
| verify_device(device_name) | |||||
| self._device_name = device_name | |||||
| state_dict = torch.load( | |||||
| os.path.join(self.model_dir, ModelFile.TORCH_MODEL_BIN_FILE), | |||||
| map_location='cpu') | |||||
| self.backbone_config = Star3Config.from_json_file( | |||||
| os.path.join(self.model_dir, ModelFile.CONFIGURATION)) | |||||
| self.backbone_model = Star3Model( | |||||
| config=self.backbone_config, schema_link_module='rat') | |||||
| self.backbone_model.load_state_dict(state_dict['backbone_model']) | |||||
| constant = Constant() | |||||
| self.agg_ops = constant.agg_ops | |||||
| self.cond_ops = constant.cond_ops | |||||
| self.cond_conn_ops = constant.cond_conn_ops | |||||
| self.action_ops = constant.action_ops | |||||
| self.max_select_num = constant.max_select_num | |||||
| self.max_where_num = constant.max_where_num | |||||
| self.col_type_dict = constant.col_type_dict | |||||
| self.schema_link_dict = constant.schema_link_dict | |||||
| n_cond_ops = len(self.cond_ops) | |||||
| n_agg_ops = len(self.agg_ops) | |||||
| n_action_ops = len(self.action_ops) | |||||
| iS = self.backbone_config.hidden_size | |||||
| self.head_model = Seq2SQL(iS, 100, 2, 0.0, n_cond_ops, n_agg_ops, | |||||
| n_action_ops, self.max_select_num, | |||||
| self.max_where_num, self._device_name) | |||||
| self.head_model.load_state_dict(state_dict['head_model'], strict=False) | |||||
| self.backbone_model.to(self._device_name) | |||||
| self.head_model.to(self._device_name) | |||||
| def convert_string(self, pr_wvi, nlu, nlu_tt): | |||||
| convs = [] | |||||
| for b, nlu1 in enumerate(nlu): | |||||
| conv_dict = {} | |||||
| nlu_tt1 = nlu_tt[b] | |||||
| idx = 0 | |||||
| convflag = True | |||||
| for i, ntok in enumerate(nlu_tt1): | |||||
| if idx >= len(nlu1): | |||||
| convflag = False | |||||
| break | |||||
| if ntok.startswith('##'): | |||||
| ntok = ntok.replace('##', '') | |||||
| tok = nlu1[idx:idx + 1].lower() | |||||
| if ntok == tok: | |||||
| conv_dict[i] = [idx, idx + 1] | |||||
| idx += 1 | |||||
| elif ntok == '#': | |||||
| conv_dict[i] = [idx, idx] | |||||
| elif ntok == '[UNK]': | |||||
| conv_dict[i] = [idx, idx + 1] | |||||
| j = i + 1 | |||||
| idx += 1 | |||||
| if idx < len(nlu1) and j < len( | |||||
| nlu_tt1) and nlu_tt1[j] != '[UNK]': | |||||
| while idx < len(nlu1): | |||||
| val = nlu1[idx:idx + 1].lower() | |||||
| if nlu_tt1[j].startswith(val): | |||||
| break | |||||
| idx += 1 | |||||
| conv_dict[i][1] = idx | |||||
| elif tok in ntok: | |||||
| startid = idx | |||||
| idx += 1 | |||||
| while idx < len(nlu1): | |||||
| tok += nlu1[idx:idx + 1].lower() | |||||
| if ntok == tok: | |||||
| conv_dict[i] = [startid, idx + 1] | |||||
| break | |||||
| idx += 1 | |||||
| idx += 1 | |||||
| else: | |||||
| convflag = False | |||||
| conv = [] | |||||
| if convflag: | |||||
| for pr_wvi1 in pr_wvi[b]: | |||||
| s1, e1 = conv_dict[pr_wvi1[0]] | |||||
| s2, e2 = conv_dict[pr_wvi1[1]] | |||||
| newidx = pr_wvi1[1] | |||||
| while newidx + 1 < len( | |||||
| nlu_tt1) and s2 == e2 and nlu_tt1[newidx] == '#': | |||||
| newidx += 1 | |||||
| s2, e2 = conv_dict[newidx] | |||||
| if newidx + 1 < len(nlu_tt1) and nlu_tt1[ | |||||
| newidx + 1].startswith('##'): | |||||
| s2, e2 = conv_dict[newidx + 1] | |||||
| phrase = nlu1[s1:e2] | |||||
| conv.append(phrase) | |||||
| else: | |||||
| for pr_wvi1 in pr_wvi[b]: | |||||
| phrase = ''.join(nlu_tt1[pr_wvi1[0]:pr_wvi1[1] | |||||
| + 1]).replace('##', '') | |||||
| conv.append(phrase) | |||||
| convs.append(conv) | |||||
| return convs | |||||
| def get_fields_info(self, t1s, tables, train=True): | |||||
| nlu, nlu_t, sql_i, q_know, t_know, action, hs_t, types, units, his_sql, schema_link = \ | |||||
| [], [], [], [], [], [], [], [], [], [], [] | |||||
| for t1 in t1s: | |||||
| nlu.append(t1['question']) | |||||
| nlu_t.append(t1['question_tok']) | |||||
| hs_t.append(t1['header_tok']) | |||||
| q_know.append(t1['bertindex_knowledge']) | |||||
| t_know.append(t1['header_knowledge']) | |||||
| types.append(t1['types']) | |||||
| units.append(t1['units']) | |||||
| his_sql.append(t1.get('history_sql', None)) | |||||
| schema_link.append(t1.get('schema_link', [])) | |||||
| if train: | |||||
| action.append(t1.get('action', [0])) | |||||
| sql_i.append(t1['sql']) | |||||
| return nlu, nlu_t, sql_i, q_know, t_know, action, hs_t, types, units, his_sql, schema_link | |||||
| def get_history_select_where(self, his_sql, header_len): | |||||
| if his_sql is None: | |||||
| return [0], [0] | |||||
| sel = [] | |||||
| for seli in his_sql['sel']: | |||||
| if seli + 1 < header_len and seli + 1 not in sel: | |||||
| sel.append(seli + 1) | |||||
| whe = [] | |||||
| for condi in his_sql['conds']: | |||||
| if condi[0] + 1 < header_len and condi[0] + 1 not in whe: | |||||
| whe.append(condi[0] + 1) | |||||
| if len(sel) == 0: | |||||
| sel.append(0) | |||||
| if len(whe) == 0: | |||||
| whe.append(0) | |||||
| sel.sort() | |||||
| whe.sort() | |||||
| return sel, whe | |||||
| def get_types_ids(self, col_type): | |||||
| for key, type_ids in self.col_type_dict.items(): | |||||
| if key in col_type.lower(): | |||||
| return type_ids | |||||
| return self.col_type_dict['null'] | |||||
| def generate_inputs(self, nlu1_tok, hs_t_1, type_t, unit_t, his_sql, | |||||
| q_know, t_know, s_link): | |||||
| tokens = [] | |||||
| orders = [] | |||||
| types = [] | |||||
| segment_ids = [] | |||||
| matchs = [] | |||||
| col_dict = {} | |||||
| schema_tok = [] | |||||
| tokens.append('[CLS]') | |||||
| orders.append(0) | |||||
| types.append(0) | |||||
| i_st_nlu = len(tokens) | |||||
| matchs.append(0) | |||||
| segment_ids.append(0) | |||||
| for idx, token in enumerate(nlu1_tok): | |||||
| if q_know[idx] == 100: | |||||
| break | |||||
| elif q_know[idx] >= 5: | |||||
| matchs.append(1) | |||||
| else: | |||||
| matchs.append(q_know[idx] + 1) | |||||
| tokens.append(token) | |||||
| orders.append(0) | |||||
| types.append(0) | |||||
| segment_ids.append(0) | |||||
| i_ed_nlu = len(tokens) | |||||
| tokens.append('[SEP]') | |||||
| orders.append(0) | |||||
| types.append(0) | |||||
| matchs.append(0) | |||||
| segment_ids.append(0) | |||||
| sel, whe = self.get_history_select_where(his_sql, len(hs_t_1)) | |||||
| if len(sel) == 1 and sel[0] == 0 \ | |||||
| and len(whe) == 1 and whe[0] == 0: | |||||
| pass | |||||
| else: | |||||
| tokens.append('select') | |||||
| orders.append(0) | |||||
| types.append(0) | |||||
| matchs.append(10) | |||||
| segment_ids.append(0) | |||||
| for seli in sel: | |||||
| tokens.append('[PAD]') | |||||
| orders.append(0) | |||||
| types.append(0) | |||||
| matchs.append(10) | |||||
| segment_ids.append(0) | |||||
| col_dict[len(tokens) - 1] = seli | |||||
| tokens.append('where') | |||||
| orders.append(0) | |||||
| types.append(0) | |||||
| matchs.append(10) | |||||
| segment_ids.append(0) | |||||
| for whei in whe: | |||||
| tokens.append('[PAD]') | |||||
| orders.append(0) | |||||
| types.append(0) | |||||
| matchs.append(10) | |||||
| segment_ids.append(0) | |||||
| col_dict[len(tokens) - 1] = whei | |||||
| tokens.append('[SEP]') | |||||
| orders.append(0) | |||||
| types.append(0) | |||||
| matchs.append(10) | |||||
| segment_ids.append(0) | |||||
| column_start = len(tokens) | |||||
| i_hds_f = [] | |||||
| header_flatten_tokens, header_flatten_index = [], [] | |||||
| for i, hds11 in enumerate(hs_t_1): | |||||
| if len(unit_t[i]) == 1 and unit_t[i][0] == 'null': | |||||
| temp_header_tokens = hds11 | |||||
| else: | |||||
| temp_header_tokens = hds11 + unit_t[i] | |||||
| schema_tok.append(temp_header_tokens) | |||||
| header_flatten_tokens.extend(temp_header_tokens) | |||||
| header_flatten_index.extend([i + 1] * len(temp_header_tokens)) | |||||
| i_st_hd_f = len(tokens) | |||||
| tokens += ['[PAD]'] | |||||
| orders.append(0) | |||||
| types.append(self.get_types_ids(type_t[i])) | |||||
| i_ed_hd_f = len(tokens) | |||||
| col_dict[len(tokens) - 1] = i | |||||
| i_hds_f.append((i_st_hd_f, i_ed_hd_f)) | |||||
| if i == 0: | |||||
| matchs.append(6) | |||||
| else: | |||||
| matchs.append(t_know[i - 1] + 6) | |||||
| segment_ids.append(1) | |||||
| tokens.append('[SEP]') | |||||
| orders.append(0) | |||||
| types.append(0) | |||||
| matchs.append(0) | |||||
| segment_ids.append(1) | |||||
| # position where | |||||
| # [SEP] | |||||
| start_ids = len(tokens) - 1 | |||||
| tokens.append('action') # action | |||||
| orders.append(1) | |||||
| types.append(0) | |||||
| matchs.append(0) | |||||
| segment_ids.append(1) | |||||
| tokens.append('connect') # column | |||||
| orders.append(1) | |||||
| types.append(0) | |||||
| matchs.append(0) | |||||
| segment_ids.append(1) | |||||
| tokens.append('allen') # select len | |||||
| orders.append(1) | |||||
| types.append(0) | |||||
| matchs.append(0) | |||||
| segment_ids.append(1) | |||||
| for x in range(self.max_where_num): | |||||
| tokens.append('act') # op | |||||
| orders.append(2 + x) | |||||
| types.append(0) | |||||
| matchs.append(0) | |||||
| segment_ids.append(1) | |||||
| tokens.append('size') # where len | |||||
| orders.append(1) | |||||
| types.append(0) | |||||
| matchs.append(0) | |||||
| segment_ids.append(1) | |||||
| for x in range(self.max_select_num): | |||||
| tokens.append('focus') # agg | |||||
| orders.append(2 + x) | |||||
| types.append(0) | |||||
| matchs.append(0) | |||||
| segment_ids.append(1) | |||||
| i_nlu = (i_st_nlu, i_ed_nlu) | |||||
| schema_link_matrix = numpy.zeros((len(tokens), len(tokens)), | |||||
| dtype='int32') | |||||
| schema_link_mask = numpy.zeros((len(tokens), len(tokens)), | |||||
| dtype='float32') | |||||
| for relation in s_link: | |||||
| if relation['label'] in ['col', 'val']: | |||||
| [q_st, q_ed] = relation['question_index'] | |||||
| cid = max(0, relation['column_index']) | |||||
| schema_link_matrix[ | |||||
| i_st_nlu + q_st: i_st_nlu + q_ed + 1, | |||||
| column_start + cid + 1: column_start + cid + 1 + 1] = \ | |||||
| self.schema_link_dict[relation['label'] + '_middle'] | |||||
| schema_link_matrix[ | |||||
| i_st_nlu + q_st, | |||||
| column_start + cid + 1: column_start + cid + 1 + 1] = \ | |||||
| self.schema_link_dict[relation['label'] + '_start'] | |||||
| schema_link_matrix[ | |||||
| i_st_nlu + q_ed, | |||||
| column_start + cid + 1: column_start + cid + 1 + 1] = \ | |||||
| self.schema_link_dict[relation['label'] + '_end'] | |||||
| schema_link_mask[i_st_nlu + q_st:i_st_nlu + q_ed + 1, | |||||
| column_start + cid + 1:column_start + cid + 1 | |||||
| + 1] = 1.0 | |||||
| return tokens, orders, types, segment_ids, matchs, \ | |||||
| i_nlu, i_hds_f, start_ids, column_start, col_dict, schema_tok, \ | |||||
| header_flatten_tokens, header_flatten_index, schema_link_matrix, schema_link_mask | |||||
| def gen_l_hpu(self, i_hds): | |||||
| """ | |||||
| Treat columns as if it is a batch of natural language utterance | |||||
| with batch-size = # of columns * # of batch_size | |||||
| i_hds = [(17, 18), (19, 21), (22, 23), (24, 25), (26, 29), (30, 34)]) | |||||
| """ | |||||
| l_hpu = [] | |||||
| for i_hds1 in i_hds: | |||||
| for i_hds11 in i_hds1: | |||||
| l_hpu.append(i_hds11[1] - i_hds11[0]) | |||||
| return l_hpu | |||||
| def get_bert_output(self, model_bert, tokenizer, nlu_t, hs_t, col_types, | |||||
| units, his_sql, q_know, t_know, schema_link): | |||||
| """ | |||||
| Here, input is toknized further by WordPiece (WP) tokenizer and fed into BERT. | |||||
| INPUT | |||||
| :param model_bert: | |||||
| :param tokenizer: WordPiece toknizer | |||||
| :param nlu: Question | |||||
| :param nlu_t: CoreNLP tokenized nlu. | |||||
| :param hds: Headers | |||||
| :param hs_t: None or 1st-level tokenized headers | |||||
| :param max_seq_length: max input token length | |||||
| OUTPUT | |||||
| tokens: BERT input tokens | |||||
| nlu_tt: WP-tokenized input natural language questions | |||||
| orig_to_tok_index: map the index of 1st-level-token to the index of 2nd-level-token | |||||
| tok_to_orig_index: inverse map. | |||||
| """ | |||||
| l_n = [] | |||||
| l_hs = [] # The length of columns for each batch | |||||
| input_ids = [] | |||||
| order_ids = [] | |||||
| type_ids = [] | |||||
| segment_ids = [] | |||||
| match_ids = [] | |||||
| input_mask = [] | |||||
| i_nlu = [ | |||||
| ] # index to retreive the position of contextual vector later. | |||||
| i_hds = [] | |||||
| tokens = [] | |||||
| orders = [] | |||||
| types = [] | |||||
| matchs = [] | |||||
| segments = [] | |||||
| schema_link_matrix_list = [] | |||||
| schema_link_mask_list = [] | |||||
| start_index = [] | |||||
| column_index = [] | |||||
| col_dict_list = [] | |||||
| header_list = [] | |||||
| header_flatten_token_list = [] | |||||
| header_flatten_tokenid_list = [] | |||||
| header_flatten_index_list = [] | |||||
| header_tok_max_len = 0 | |||||
| cur_max_length = 0 | |||||
| for b, nlu_t1 in enumerate(nlu_t): | |||||
| hs_t1 = [hs_t[b][-1]] + hs_t[b][:-1] | |||||
| type_t1 = [col_types[b][-1]] + col_types[b][:-1] | |||||
| unit_t1 = [units[b][-1]] + units[b][:-1] | |||||
| l_hs.append(len(hs_t1)) | |||||
| # [CLS] nlu [SEP] col1 [SEP] col2 [SEP] ...col-n [SEP] | |||||
| # 2. Generate BERT inputs & indices. | |||||
| tokens1, orders1, types1, segment1, match1, i_nlu1, i_hds_1, \ | |||||
| start_idx, column_start, col_dict, schema_tok, \ | |||||
| header_flatten_tokens, header_flatten_index, schema_link_matrix, schema_link_mask = \ | |||||
| self.generate_inputs( | |||||
| nlu_t1, hs_t1, type_t1, unit_t1, his_sql[b], | |||||
| q_know[b], t_know[b], schema_link[b]) | |||||
| l_n.append(i_nlu1[1] - i_nlu1[0]) | |||||
| start_index.append(start_idx) | |||||
| column_index.append(column_start) | |||||
| col_dict_list.append(col_dict) | |||||
| tokens.append(tokens1) | |||||
| orders.append(orders1) | |||||
| types.append(types1) | |||||
| segments.append(segment1) | |||||
| matchs.append(match1) | |||||
| i_nlu.append(i_nlu1) | |||||
| i_hds.append(i_hds_1) | |||||
| schema_link_matrix_list.append(schema_link_matrix) | |||||
| schema_link_mask_list.append(schema_link_mask) | |||||
| header_flatten_token_list.append(header_flatten_tokens) | |||||
| header_flatten_index_list.append(header_flatten_index) | |||||
| header_list.append(schema_tok) | |||||
| header_max = max([len(schema_tok1) for schema_tok1 in schema_tok]) | |||||
| if header_max > header_tok_max_len: | |||||
| header_tok_max_len = header_max | |||||
| if len(tokens1) > cur_max_length: | |||||
| cur_max_length = len(tokens1) | |||||
| if len(tokens1) > 512: | |||||
| print('input too long!!! total_num:%d\t question:%s' % | |||||
| (len(tokens1), ''.join(nlu_t1))) | |||||
| assert cur_max_length <= 512 | |||||
| for i, tokens1 in enumerate(tokens): | |||||
| segment_ids1 = segments[i] | |||||
| order_ids1 = orders[i] | |||||
| type_ids1 = types[i] | |||||
| match_ids1 = matchs[i] | |||||
| input_ids1 = tokenizer.convert_tokens_to_ids(tokens1) | |||||
| input_mask1 = [1] * len(input_ids1) | |||||
| while len(input_ids1) < cur_max_length: | |||||
| input_ids1.append(0) | |||||
| input_mask1.append(0) | |||||
| segment_ids1.append(0) | |||||
| order_ids1.append(0) | |||||
| type_ids1.append(0) | |||||
| match_ids1.append(0) | |||||
| if len(input_ids1) != cur_max_length: | |||||
| print('Error: ', nlu_t1, tokens1, len(input_ids1), | |||||
| cur_max_length) | |||||
| assert len(input_ids1) == cur_max_length | |||||
| assert len(input_mask1) == cur_max_length | |||||
| assert len(order_ids1) == cur_max_length | |||||
| assert len(segment_ids1) == cur_max_length | |||||
| assert len(match_ids1) == cur_max_length | |||||
| assert len(type_ids1) == cur_max_length | |||||
| input_ids.append(input_ids1) | |||||
| order_ids.append(order_ids1) | |||||
| type_ids.append(type_ids1) | |||||
| segment_ids.append(segment_ids1) | |||||
| input_mask.append(input_mask1) | |||||
| match_ids.append(match_ids1) | |||||
| header_len = [] | |||||
| header_ids = [] | |||||
| header_max_len = max( | |||||
| [len(header_list1) for header_list1 in header_list]) | |||||
| for header1 in header_list: | |||||
| header_len1 = [] | |||||
| header_ids1 = [] | |||||
| for header_tok in header1: | |||||
| header_len1.append(len(header_tok)) | |||||
| header_tok_ids1 = tokenizer.convert_tokens_to_ids(header_tok) | |||||
| while len(header_tok_ids1) < header_tok_max_len: | |||||
| header_tok_ids1.append(0) | |||||
| header_ids1.append(header_tok_ids1) | |||||
| while len(header_ids1) < header_max_len: | |||||
| header_ids1.append([0] * header_tok_max_len) | |||||
| header_len.append(header_len1) | |||||
| header_ids.append(header_ids1) | |||||
| for i, header_flatten_token in enumerate(header_flatten_token_list): | |||||
| header_flatten_tokenid = tokenizer.convert_tokens_to_ids( | |||||
| header_flatten_token) | |||||
| header_flatten_tokenid_list.append(header_flatten_tokenid) | |||||
| # Convert to tensor | |||||
| all_input_ids = torch.tensor( | |||||
| input_ids, dtype=torch.long).to(self._device_name) | |||||
| all_order_ids = torch.tensor( | |||||
| order_ids, dtype=torch.long).to(self._device_name) | |||||
| all_type_ids = torch.tensor( | |||||
| type_ids, dtype=torch.long).to(self._device_name) | |||||
| all_input_mask = torch.tensor( | |||||
| input_mask, dtype=torch.long).to(self._device_name) | |||||
| all_segment_ids = torch.tensor( | |||||
| segment_ids, dtype=torch.long).to(self._device_name) | |||||
| all_match_ids = torch.tensor( | |||||
| match_ids, dtype=torch.long).to(self._device_name) | |||||
| all_header_ids = torch.tensor( | |||||
| header_ids, dtype=torch.long).to(self._device_name) | |||||
| all_ids = torch.arange( | |||||
| all_input_ids.shape[0], dtype=torch.long).to(self._device_name) | |||||
| bS = len(header_flatten_tokenid_list) | |||||
| max_header_flatten_token_length = max( | |||||
| [len(x) for x in header_flatten_tokenid_list]) | |||||
| all_header_flatten_tokens = numpy.zeros( | |||||
| (bS, max_header_flatten_token_length), dtype='int32') | |||||
| all_header_flatten_index = numpy.zeros( | |||||
| (bS, max_header_flatten_token_length), dtype='int32') | |||||
| for i, header_flatten_tokenid in enumerate( | |||||
| header_flatten_tokenid_list): | |||||
| for j, tokenid in enumerate(header_flatten_tokenid): | |||||
| all_header_flatten_tokens[i, j] = tokenid | |||||
| for j, hdindex in enumerate(header_flatten_index_list[i]): | |||||
| all_header_flatten_index[i, j] = hdindex | |||||
| all_header_flatten_output = numpy.zeros((bS, header_max_len + 1), | |||||
| dtype='int32') | |||||
| all_header_flatten_tokens = torch.tensor( | |||||
| all_header_flatten_tokens, dtype=torch.long).to(self._device_name) | |||||
| all_header_flatten_index = torch.tensor( | |||||
| all_header_flatten_index, dtype=torch.long).to(self._device_name) | |||||
| all_header_flatten_output = torch.tensor( | |||||
| all_header_flatten_output, | |||||
| dtype=torch.float32).to(self._device_name) | |||||
| all_token_column_id = numpy.zeros((bS, cur_max_length), dtype='int32') | |||||
| all_token_column_mask = numpy.zeros((bS, cur_max_length), | |||||
| dtype='float32') | |||||
| for bi, col_dict in enumerate(col_dict_list): | |||||
| for ki, vi in col_dict.items(): | |||||
| all_token_column_id[bi, ki] = vi + 1 | |||||
| all_token_column_mask[bi, ki] = 1.0 | |||||
| all_token_column_id = torch.tensor( | |||||
| all_token_column_id, dtype=torch.long).to(self._device_name) | |||||
| all_token_column_mask = torch.tensor( | |||||
| all_token_column_mask, dtype=torch.float32).to(self._device_name) | |||||
| all_schema_link_matrix = numpy.zeros( | |||||
| (bS, cur_max_length, cur_max_length), dtype='int32') | |||||
| all_schema_link_mask = numpy.zeros( | |||||
| (bS, cur_max_length, cur_max_length), dtype='float32') | |||||
| for i, schema_link_matrix in enumerate(schema_link_matrix_list): | |||||
| temp_len = schema_link_matrix.shape[0] | |||||
| all_schema_link_matrix[i, 0:temp_len, | |||||
| 0:temp_len] = schema_link_matrix | |||||
| all_schema_link_mask[i, 0:temp_len, | |||||
| 0:temp_len] = schema_link_mask_list[i] | |||||
| all_schema_link_matrix = torch.tensor( | |||||
| all_schema_link_matrix, dtype=torch.long).to(self._device_name) | |||||
| all_schema_link_mask = torch.tensor( | |||||
| all_schema_link_mask, dtype=torch.long).to(self._device_name) | |||||
| # 5. generate l_hpu from i_hds | |||||
| l_hpu = self.gen_l_hpu(i_hds) | |||||
| # 4. Generate BERT output. | |||||
| all_encoder_layer, pooled_output = model_bert( | |||||
| all_input_ids, | |||||
| all_header_ids, | |||||
| token_order_ids=all_order_ids, | |||||
| token_type_ids=all_segment_ids, | |||||
| attention_mask=all_input_mask, | |||||
| match_type_ids=all_match_ids, | |||||
| l_hs=l_hs, | |||||
| header_len=header_len, | |||||
| type_ids=all_type_ids, | |||||
| col_dict_list=col_dict_list, | |||||
| ids=all_ids, | |||||
| header_flatten_tokens=all_header_flatten_tokens, | |||||
| header_flatten_index=all_header_flatten_index, | |||||
| header_flatten_output=all_header_flatten_output, | |||||
| token_column_id=all_token_column_id, | |||||
| token_column_mask=all_token_column_mask, | |||||
| column_start_index=column_index, | |||||
| headers_length=l_hs, | |||||
| all_schema_link_matrix=all_schema_link_matrix, | |||||
| all_schema_link_mask=all_schema_link_mask, | |||||
| output_all_encoded_layers=False) | |||||
| return all_encoder_layer, pooled_output, tokens, i_nlu, i_hds, \ | |||||
| l_n, l_hpu, l_hs, start_index, column_index, all_ids | |||||
| def predict(self, querys): | |||||
| self.head_model.eval() | |||||
| self.backbone_model.eval() | |||||
| nlu, nlu_t, sql_i, q_know, t_know, tb, hs_t, types, units, his_sql, schema_link = \ | |||||
| self.get_fields_info(querys, None, train=False) | |||||
| with torch.no_grad(): | |||||
| all_encoder_layer, _, tokens, i_nlu, i_hds, l_n, l_hpu, l_hs, start_index, column_index, ids = \ | |||||
| self.get_bert_output( | |||||
| self.backbone_model, self.tokenizer, | |||||
| nlu_t, hs_t, types, units, his_sql, q_know, t_know, schema_link) | |||||
| s_action, s_sc, s_sa, s_cco, s_wc, s_wo, s_wvs, s_len = self.head_model( | |||||
| all_encoder_layer, l_n, l_hs, start_index, column_index, | |||||
| tokens, ids) | |||||
| action_batch = torch.argmax(F.softmax(s_action, -1), -1).cpu().tolist() | |||||
| scco_batch = torch.argmax(F.softmax(s_cco, -1), -1).cpu().tolist() | |||||
| sc_batch = torch.argmax(F.softmax(s_sc, -1), -1).cpu().tolist() | |||||
| sa_batch = torch.argmax(F.softmax(s_sa, -1), -1).cpu().tolist() | |||||
| wc_batch = torch.argmax(F.softmax(s_wc, -1), -1).cpu().tolist() | |||||
| wo_batch = torch.argmax(F.softmax(s_wo, -1), -1).cpu().tolist() | |||||
| s_wvs_s, s_wvs_e = s_wvs | |||||
| wvss_batch = torch.argmax(F.softmax(s_wvs_s, -1), -1).cpu().tolist() | |||||
| wvse_batch = torch.argmax(F.softmax(s_wvs_e, -1), -1).cpu().tolist() | |||||
| s_slen, s_wlen = s_len | |||||
| slen_batch = torch.argmax(F.softmax(s_slen, -1), -1).cpu().tolist() | |||||
| wlen_batch = torch.argmax(F.softmax(s_wlen, -1), -1).cpu().tolist() | |||||
| pr_wvi = [] | |||||
| for i in range(len(querys)): | |||||
| wvi = [] | |||||
| for j in range(wlen_batch[i]): | |||||
| wvi.append([ | |||||
| max(0, wvss_batch[i][j] - 1), | |||||
| max(0, wvse_batch[i][j] - 1) | |||||
| ]) | |||||
| pr_wvi.append(wvi) | |||||
| pr_wvi_str = self.convert_string(pr_wvi, nlu, nlu_t) | |||||
| pre_results = [] | |||||
| for ib in range(len(querys)): | |||||
| res_one = {} | |||||
| sql = {} | |||||
| sql['cond_conn_op'] = scco_batch[ib] | |||||
| sl = slen_batch[ib] | |||||
| sql['sel'] = list( | |||||
| numpy.array(sc_batch[ib][:sl]).astype(numpy.int32) - 1) | |||||
| sql['agg'] = list( | |||||
| numpy.array(sa_batch[ib][:sl]).astype(numpy.int32)) | |||||
| sels = [] | |||||
| aggs = [] | |||||
| for ia, sel in enumerate(sql['sel']): | |||||
| if sel == -1: | |||||
| if sql['agg'][ia] > 0: | |||||
| sels.append(l_hs[ib] - 1) | |||||
| aggs.append(sql['agg'][ia]) | |||||
| continue | |||||
| sels.append(sel) | |||||
| if sql['agg'][ia] == -1: | |||||
| aggs.append(0) | |||||
| else: | |||||
| aggs.append(sql['agg'][ia]) | |||||
| if len(sels) == 0: | |||||
| sels.append(l_hs[ib] - 1) | |||||
| aggs.append(0) | |||||
| assert len(sels) == len(aggs) | |||||
| sql['sel'] = sels | |||||
| sql['agg'] = aggs | |||||
| conds = [] | |||||
| wl = wlen_batch[ib] | |||||
| wc_os = list( | |||||
| numpy.array(wc_batch[ib][:wl]).astype(numpy.int32) - 1) | |||||
| wo_os = list(numpy.array(wo_batch[ib][:wl]).astype(numpy.int32)) | |||||
| res_one['question_tok'] = querys[ib]['question_tok'] | |||||
| for i in range(wl): | |||||
| if wc_os[i] == -1: | |||||
| continue | |||||
| conds.append([wc_os[i], wo_os[i], pr_wvi_str[ib][i]]) | |||||
| if len(conds) == 0: | |||||
| conds.append([l_hs[ib] - 1, 2, 'Nulll']) | |||||
| sql['conds'] = conds | |||||
| res_one['question'] = querys[ib]['question'] | |||||
| res_one['table_id'] = querys[ib]['table_id'] | |||||
| res_one['sql'] = sql | |||||
| res_one['action'] = action_batch[ib] | |||||
| res_one['model_out'] = [ | |||||
| sc_batch[ib], sa_batch[ib], wc_batch[ib], wo_batch[ib], | |||||
| wvss_batch[ib], wvse_batch[ib] | |||||
| ] | |||||
| pre_results.append(res_one) | |||||
| return pre_results | |||||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||||
| """return the result by the model | |||||
| Args: | |||||
| input (Dict[str, Tensor]): the preprocessed data | |||||
| Returns: | |||||
| Dict[str, Tensor]: results | |||||
| Example: | |||||
| """ | |||||
| result = self.predict(input['datas'])[0] | |||||
| return { | |||||
| 'result': result, | |||||
| 'history_sql': input['datas'][0]['history_sql'] | |||||
| } | |||||
| @@ -35,6 +35,7 @@ class OutputKeys(object): | |||||
| UUID = 'uuid' | UUID = 'uuid' | ||||
| WORD = 'word' | WORD = 'word' | ||||
| KWS_LIST = 'kws_list' | KWS_LIST = 'kws_list' | ||||
| HISTORY = 'history' | |||||
| TIMESTAMPS = 'timestamps' | TIMESTAMPS = 'timestamps' | ||||
| SPLIT_VIDEO_NUM = 'split_video_num' | SPLIT_VIDEO_NUM = 'split_video_num' | ||||
| SPLIT_META_DICT = 'split_meta_dict' | SPLIT_META_DICT = 'split_meta_dict' | ||||
| @@ -471,6 +472,13 @@ TASK_OUTPUTS = { | |||||
| # } | # } | ||||
| Tasks.conversational_text_to_sql: [OutputKeys.TEXT], | Tasks.conversational_text_to_sql: [OutputKeys.TEXT], | ||||
| # table-question-answering result for single sample | |||||
| # { | |||||
| # "sql": "SELECT shop.Name FROM shop." | |||||
| # "sql_history": {sel: 0, agg: 0, conds: [[0, 0, 'val']]} | |||||
| # } | |||||
| Tasks.table_question_answering: [OutputKeys.OUTPUT, OutputKeys.HISTORY], | |||||
| # ============ audio tasks =================== | # ============ audio tasks =================== | ||||
| # asr result for single sample | # asr result for single sample | ||||
| # { "text": "每一天都要快乐喔"} | # { "text": "每一天都要快乐喔"} | ||||
| @@ -66,6 +66,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| Tasks.conversational_text_to_sql: | Tasks.conversational_text_to_sql: | ||||
| (Pipelines.conversational_text_to_sql, | (Pipelines.conversational_text_to_sql, | ||||
| 'damo/nlp_star_conversational-text-to-sql'), | 'damo/nlp_star_conversational-text-to-sql'), | ||||
| Tasks.table_question_answering: | |||||
| (Pipelines.table_question_answering_pipeline, | |||||
| 'damo/nlp-convai-text2sql-pretrain-cn'), | |||||
| Tasks.text_error_correction: | Tasks.text_error_correction: | ||||
| (Pipelines.text_error_correction, | (Pipelines.text_error_correction, | ||||
| 'damo/nlp_bart_text-error-correction_chinese'), | 'damo/nlp_bart_text-error-correction_chinese'), | ||||
| @@ -5,6 +5,7 @@ from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from .conversational_text_to_sql_pipeline import ConversationalTextToSqlPipeline | from .conversational_text_to_sql_pipeline import ConversationalTextToSqlPipeline | ||||
| from .table_question_answering_pipeline import TableQuestionAnsweringPipeline | |||||
| from .dialog_intent_prediction_pipeline import DialogIntentPredictionPipeline | from .dialog_intent_prediction_pipeline import DialogIntentPredictionPipeline | ||||
| from .dialog_modeling_pipeline import DialogModelingPipeline | from .dialog_modeling_pipeline import DialogModelingPipeline | ||||
| from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline | from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline | ||||
| @@ -31,6 +32,8 @@ else: | |||||
| _import_structure = { | _import_structure = { | ||||
| 'conversational_text_to_sql_pipeline': | 'conversational_text_to_sql_pipeline': | ||||
| ['ConversationalTextToSqlPipeline'], | ['ConversationalTextToSqlPipeline'], | ||||
| 'table_question_answering_pipeline': | |||||
| ['TableQuestionAnsweringPipeline'], | |||||
| 'dialog_intent_prediction_pipeline': | 'dialog_intent_prediction_pipeline': | ||||
| ['DialogIntentPredictionPipeline'], | ['DialogIntentPredictionPipeline'], | ||||
| 'dialog_modeling_pipeline': ['DialogModelingPipeline'], | 'dialog_modeling_pipeline': ['DialogModelingPipeline'], | ||||
| @@ -0,0 +1,284 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| from typing import Any, Dict, Union | |||||
| import torch | |||||
| from transformers import BertTokenizer | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models import Model | |||||
| from modelscope.models.nlp import TableQuestionAnswering | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import TableQuestionAnsweringPreprocessor | |||||
| from modelscope.preprocessors.star3.fields.database import Database | |||||
| from modelscope.preprocessors.star3.fields.struct import Constant, SQLQuery | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| __all__ = ['TableQuestionAnsweringPipeline'] | |||||
| @PIPELINES.register_module( | |||||
| Tasks.table_question_answering, | |||||
| module_name=Pipelines.table_question_answering_pipeline) | |||||
| class TableQuestionAnsweringPipeline(Pipeline): | |||||
| def __init__(self, | |||||
| model: Union[TableQuestionAnswering, str], | |||||
| preprocessor: TableQuestionAnsweringPreprocessor = None, | |||||
| db: Database = None, | |||||
| **kwargs): | |||||
| """use `model` and `preprocessor` to create a table question answering prediction pipeline | |||||
| Args: | |||||
| model (TableQuestionAnswering): a model instance | |||||
| preprocessor (TableQuestionAnsweringPreprocessor): a preprocessor instance | |||||
| db (Database): a database to store tables in the database | |||||
| """ | |||||
| model = model if isinstance( | |||||
| model, TableQuestionAnswering) else Model.from_pretrained(model) | |||||
| if preprocessor is None: | |||||
| preprocessor = TableQuestionAnsweringPreprocessor(model.model_dir) | |||||
| # initilize tokenizer | |||||
| self.tokenizer = BertTokenizer( | |||||
| os.path.join(model.model_dir, ModelFile.VOCAB_FILE)) | |||||
| # initialize database | |||||
| if db is None: | |||||
| self.db = Database( | |||||
| tokenizer=self.tokenizer, | |||||
| table_file_path=os.path.join(model.model_dir, 'table.json'), | |||||
| syn_dict_file_path=os.path.join(model.model_dir, | |||||
| 'synonym.txt')) | |||||
| else: | |||||
| self.db = db | |||||
| constant = Constant() | |||||
| self.agg_ops = constant.agg_ops | |||||
| self.cond_ops = constant.cond_ops | |||||
| self.cond_conn_ops = constant.cond_conn_ops | |||||
| self.action_ops = constant.action_ops | |||||
| self.max_select_num = constant.max_select_num | |||||
| self.max_where_num = constant.max_where_num | |||||
| self.col_type_dict = constant.col_type_dict | |||||
| self.schema_link_dict = constant.schema_link_dict | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| def post_process_multi_turn(self, history_sql, result, table): | |||||
| action = self.action_ops[result['action']] | |||||
| headers = table['header_name'] | |||||
| current_sql = result['sql'] | |||||
| if history_sql is None: | |||||
| return current_sql | |||||
| if action == 'out_of_scripts': | |||||
| return history_sql | |||||
| elif action == 'switch_table': | |||||
| return current_sql | |||||
| elif action == 'restart': | |||||
| return current_sql | |||||
| elif action == 'firstTurn': | |||||
| return current_sql | |||||
| elif action == 'del_focus': | |||||
| pre_final_sql = copy.deepcopy(history_sql) | |||||
| pre_sels = [] | |||||
| pre_aggs = [] | |||||
| for idx, seli in enumerate(pre_final_sql['sel']): | |||||
| if seli not in current_sql['sel']: | |||||
| pre_sels.append(seli) | |||||
| pre_aggs.append(pre_final_sql['agg'][idx]) | |||||
| if len(pre_sels) < 1: | |||||
| pre_sels.append(len(headers)) | |||||
| pre_aggs.append(0) | |||||
| pre_final_sql['sel'] = pre_sels | |||||
| pre_final_sql['agg'] = pre_aggs | |||||
| final_conds = [] | |||||
| for condi in pre_final_sql['conds']: | |||||
| if condi[0] < len(headers): | |||||
| final_conds.append(condi) | |||||
| if len(final_conds) < 1: | |||||
| final_conds.append([len(headers), 2, 'Null']) | |||||
| pre_final_sql['conds'] = final_conds | |||||
| return pre_final_sql | |||||
| elif action == 'change_agg_only': | |||||
| pre_final_sql = history_sql | |||||
| pre_sels = [] | |||||
| pre_aggs = [] | |||||
| for idx, seli in enumerate(pre_final_sql['sel']): | |||||
| if seli in current_sql['sel']: | |||||
| pre_sels.append(seli) | |||||
| changed_aggi = -1 | |||||
| for idx_single, aggi in enumerate(current_sql['agg']): | |||||
| if current_sql['sel'][idx_single] == seli: | |||||
| changed_aggi = aggi | |||||
| pre_aggs.append(changed_aggi) | |||||
| else: | |||||
| pre_sels.append(seli) | |||||
| pre_aggs.append(pre_final_sql['agg'][idx]) | |||||
| pre_final_sql['sel'] = pre_sels | |||||
| pre_final_sql['agg'] = pre_aggs | |||||
| return pre_final_sql | |||||
| elif action == 'change_focus_total': | |||||
| pre_final_sql = history_sql | |||||
| pre_sels = current_sql['sel'] | |||||
| pre_aggs = current_sql['agg'] | |||||
| pre_final_sql['sel'] = pre_sels | |||||
| pre_final_sql['agg'] = pre_aggs | |||||
| for pre_condi in current_sql['conds']: | |||||
| if pre_condi[0] < len(headers): | |||||
| in_flag = False | |||||
| for history_condi in history_sql['conds']: | |||||
| if pre_condi[0] == history_condi[0]: | |||||
| in_flag = True | |||||
| if not in_flag: | |||||
| pre_final_sql['conds'].append(pre_condi) | |||||
| return pre_final_sql | |||||
| elif action == 'del_cond': | |||||
| pre_final_sql = copy.deepcopy(history_sql) | |||||
| final_conds = [] | |||||
| for idx, condi in enumerate(pre_final_sql['conds']): | |||||
| if condi[0] not in current_sql['sel']: | |||||
| final_conds.append(condi) | |||||
| pre_final_sql['conds'] = final_conds | |||||
| final_conds = [] | |||||
| for condi in pre_final_sql['conds']: | |||||
| if condi[0] < len(headers): | |||||
| final_conds.append(condi) | |||||
| if len(final_conds) < 1: | |||||
| final_conds.append([len(headers), 2, 'Null']) | |||||
| pre_final_sql['conds'] = final_conds | |||||
| return pre_final_sql | |||||
| elif action == 'change_cond': | |||||
| pre_final_sql = history_sql | |||||
| final_conds = [] | |||||
| for idx, condi in enumerate(pre_final_sql['conds']): | |||||
| in_single_flag = False | |||||
| for single_condi in current_sql['conds']: | |||||
| if condi[0] == single_condi[0]: | |||||
| in_single_flag = True | |||||
| final_conds.append(single_condi) | |||||
| if not in_single_flag: | |||||
| final_conds.append(condi) | |||||
| pre_final_sql['conds'] = final_conds | |||||
| final_conds = [] | |||||
| for condi in pre_final_sql['conds']: | |||||
| if condi[0] < len(headers): | |||||
| final_conds.append(condi) | |||||
| if len(final_conds) < 1: | |||||
| final_conds.append([len(headers), 2, 'Null', 'Null']) | |||||
| pre_final_sql['conds'] = final_conds | |||||
| return pre_final_sql | |||||
| elif action == 'add_cond': | |||||
| pre_final_sql = history_sql | |||||
| final_conds = pre_final_sql['conds'] | |||||
| for idx, condi in enumerate(current_sql['conds']): | |||||
| if condi[0] < len(headers): | |||||
| final_conds.append(condi) | |||||
| pre_final_sql['conds'] = final_conds | |||||
| final_conds = [] | |||||
| for condi in pre_final_sql['conds']: | |||||
| if condi[0] < len(headers): | |||||
| final_conds.append(condi) | |||||
| if len(final_conds) < 1: | |||||
| final_conds.append([len(headers), 2, 'Null']) | |||||
| pre_final_sql['conds'] = final_conds | |||||
| return pre_final_sql | |||||
| else: | |||||
| return current_sql | |||||
| def sql_dict_to_str(self, result, table): | |||||
| """ | |||||
| convert sql struct to string | |||||
| """ | |||||
| header_names = table['header_name'] + ['空列'] | |||||
| header_ids = table['header_id'] + ['null'] | |||||
| sql = result['sql'] | |||||
| str_sel_list, sql_sel_list = [], [] | |||||
| for idx, sel in enumerate(sql['sel']): | |||||
| header_name = header_names[sel] | |||||
| header_id = '`%s`.`%s`' % (table['table_id'], header_ids[sel]) | |||||
| if sql['agg'][idx] == 0: | |||||
| str_sel_list.append(header_name) | |||||
| sql_sel_list.append(header_id) | |||||
| else: | |||||
| str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '( ' | |||||
| + header_name + ' )') | |||||
| sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '( ' | |||||
| + header_id + ' )') | |||||
| str_cond_list, sql_cond_list = [], [] | |||||
| for cond in sql['conds']: | |||||
| header_name = header_names[cond[0]] | |||||
| header_id = '`%s`.`%s`' % (table['table_id'], header_ids[cond[0]]) | |||||
| op = self.cond_ops[cond[1]] | |||||
| value = cond[2] | |||||
| str_cond_list.append('( ' + header_name + ' ' + op + ' "' + value | |||||
| + '" )') | |||||
| sql_cond_list.append('( ' + header_id + ' ' + op + ' "' + value | |||||
| + '" )') | |||||
| cond = ' ' + self.cond_conn_ops[sql['cond_conn_op']] + ' ' | |||||
| final_str = 'SELECT %s FROM %s WHERE %s' % (', '.join(str_sel_list), | |||||
| table['table_name'], | |||||
| cond.join(str_cond_list)) | |||||
| final_sql = 'SELECT %s FROM `%s` WHERE %s' % (', '.join(sql_sel_list), | |||||
| table['table_id'], | |||||
| cond.join(sql_cond_list)) | |||||
| sql = SQLQuery( | |||||
| string=final_str, query=final_sql, sql_result=result['sql']) | |||||
| return sql | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | |||||
| """process the prediction results | |||||
| Args: | |||||
| inputs (Dict[str, Any]): _description_ | |||||
| Returns: | |||||
| Dict[str, str]: the prediction results | |||||
| """ | |||||
| result = inputs['result'] | |||||
| history_sql = inputs['history_sql'] | |||||
| result['sql'] = self.post_process_multi_turn( | |||||
| history_sql=history_sql, | |||||
| result=result, | |||||
| table=self.db.tables[result['table_id']]) | |||||
| sql = self.sql_dict_to_str( | |||||
| result=result, table=self.db.tables[result['table_id']]) | |||||
| output = {OutputKeys.OUTPUT: sql, OutputKeys.HISTORY: result['sql']} | |||||
| return output | |||||
| def _collate_fn(self, data): | |||||
| return data | |||||
| @@ -30,6 +30,7 @@ if TYPE_CHECKING: | |||||
| DialogStateTrackingPreprocessor) | DialogStateTrackingPreprocessor) | ||||
| from .video import ReadVideoData, MovieSceneSegmentationPreprocessor | from .video import ReadVideoData, MovieSceneSegmentationPreprocessor | ||||
| from .star import ConversationalTextToSqlPreprocessor | from .star import ConversationalTextToSqlPreprocessor | ||||
| from .star3 import TableQuestionAnsweringPreprocessor | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| @@ -62,6 +63,7 @@ else: | |||||
| 'DialogStateTrackingPreprocessor', 'InputFeatures' | 'DialogStateTrackingPreprocessor', 'InputFeatures' | ||||
| ], | ], | ||||
| 'star': ['ConversationalTextToSqlPreprocessor'], | 'star': ['ConversationalTextToSqlPreprocessor'], | ||||
| 'star3': ['TableQuestionAnsweringPreprocessor'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1,24 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | |||||
| from .table_question_answering_preprocessor import TableQuestionAnsweringPreprocessor | |||||
| from .fields import MultiWOZBPETextField, IntentBPETextField | |||||
| else: | |||||
| _import_structure = { | |||||
| 'table_question_answering_preprocessor': | |||||
| ['TableQuestionAnsweringPreprocessor'], | |||||
| } | |||||
| import sys | |||||
| sys.modules[__name__] = LazyImportModule( | |||||
| __name__, | |||||
| globals()['__file__'], | |||||
| _import_structure, | |||||
| module_spec=__spec__, | |||||
| extra_objects={}, | |||||
| ) | |||||
| @@ -0,0 +1,77 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import json | |||||
| import tqdm | |||||
| from modelscope.preprocessors.star3.fields.struct import Trie | |||||
| class Database: | |||||
| def __init__(self, tokenizer, table_file_path, syn_dict_file_path): | |||||
| self.tokenizer = tokenizer | |||||
| self.tables = self.init_tables(table_file_path=table_file_path) | |||||
| self.syn_dict = self.init_syn_dict( | |||||
| syn_dict_file_path=syn_dict_file_path) | |||||
| def init_tables(self, table_file_path): | |||||
| tables = {} | |||||
| lines = [] | |||||
| with open(table_file_path, 'r') as fo: | |||||
| for line in fo: | |||||
| lines.append(line) | |||||
| for line in tqdm.tqdm(lines, desc='Load Tables'): | |||||
| table = json.loads(line.strip()) | |||||
| table_header_length = 0 | |||||
| headers_tokens = [] | |||||
| for header in table['header_name']: | |||||
| header_tokens = self.tokenizer.tokenize(header) | |||||
| table_header_length += len(header_tokens) | |||||
| headers_tokens.append(header_tokens) | |||||
| empty_column = self.tokenizer.tokenize('空列') | |||||
| table_header_length += len(empty_column) | |||||
| headers_tokens.append(empty_column) | |||||
| table['tablelen'] = table_header_length | |||||
| table['header_tok'] = headers_tokens | |||||
| table['header_types'].append('null') | |||||
| table['header_units'] = [ | |||||
| self.tokenizer.tokenize(unit) for unit in table['header_units'] | |||||
| ] + [[]] | |||||
| trie_set = [Trie() for _ in table['header_name']] | |||||
| for row in table['rows']: | |||||
| for ii, cell in enumerate(row): | |||||
| if 'real' in table['header_types'][ii].lower() or \ | |||||
| 'number' in table['header_types'][ii].lower() or \ | |||||
| 'duration' in table['header_types'][ii].lower(): | |||||
| continue | |||||
| word = str(cell).strip().lower() | |||||
| trie_set[ii].insert(word, word) | |||||
| table['value_trie'] = trie_set | |||||
| tables[table['table_id']] = table | |||||
| return tables | |||||
| def init_syn_dict(self, syn_dict_file_path): | |||||
| lines = [] | |||||
| with open(syn_dict_file_path, encoding='utf-8') as fo: | |||||
| for line in fo: | |||||
| lines.append(line) | |||||
| syn_dict = {} | |||||
| for line in tqdm.tqdm(lines, desc='Load Synonym Dict'): | |||||
| tokens = line.strip().split('\t') | |||||
| if len(tokens) != 2: | |||||
| continue | |||||
| keys = tokens[0].strip().split('|') | |||||
| values = tokens[1].strip().split('|') | |||||
| for key in keys: | |||||
| key = key.lower().strip() | |||||
| syn_dict.setdefault(key, []) | |||||
| for value in values: | |||||
| syn_dict[key].append(value.lower().strip()) | |||||
| return syn_dict | |||||
| @@ -0,0 +1,423 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import re | |||||
| from modelscope.preprocessors.star3.fields.struct import TypeInfo | |||||
| class SchemaLinker: | |||||
| def __init__(self): | |||||
| pass | |||||
| def find_in_list(self, comlist, words): | |||||
| result = False | |||||
| for com in comlist: | |||||
| if words in com: | |||||
| result = True | |||||
| break | |||||
| return result | |||||
| def get_continue_score(self, pstr, tstr): | |||||
| comlist = [] | |||||
| minlen = min(len(pstr), len(tstr)) | |||||
| for slen in range(minlen, 1, -1): | |||||
| for ts in range(0, len(tstr), 1): | |||||
| if ts + slen > len(tstr): | |||||
| continue | |||||
| words = tstr[ts:ts + slen] | |||||
| if words in pstr and not self.find_in_list(comlist, words): | |||||
| comlist.append(words) | |||||
| comlen = 0 | |||||
| for com in comlist: | |||||
| comlen += len(com) * len(com) | |||||
| weight = comlen / (len(tstr) * len(tstr) + 0.001) | |||||
| if weight > 1.0: | |||||
| weight = 1.0 | |||||
| return weight | |||||
| def get_match_score(self, ptokens, ttokens): | |||||
| pset = set(ptokens) | |||||
| tset = set(ttokens) | |||||
| comset = pset & tset | |||||
| allset = pset | tset | |||||
| weight2 = len(comset) / (len(allset) + 0.001) | |||||
| weight3 = self.get_continue_score(''.join(ptokens), ''.join(ttokens)) | |||||
| return 0.4 * weight2 + 0.6 * weight3 | |||||
| def is_number(self, s): | |||||
| try: | |||||
| float(s) | |||||
| return True | |||||
| except ValueError: | |||||
| pass | |||||
| try: | |||||
| import unicodedata | |||||
| unicodedata.numeric(s) | |||||
| return True | |||||
| except (TypeError, ValueError): | |||||
| pass | |||||
| return False | |||||
| def get_match_phrase(self, query, target): | |||||
| if target in query: | |||||
| return target, 1.0 | |||||
| qtokens = [] | |||||
| for i in range(0, len(query), 1): | |||||
| qtokens.append(query[i:i + 1]) | |||||
| ttokens = [] | |||||
| for i in range(0, len(target), 1): | |||||
| ttokens.append(target[i:i + 1]) | |||||
| ttok_set = set(ttokens) | |||||
| phrase = '' | |||||
| score = 0.0 | |||||
| for qidx, qword in enumerate(qtokens): | |||||
| if qword not in ttok_set: | |||||
| continue | |||||
| eidx = (qidx + 2 * len(ttokens)) if ( | |||||
| len(qtokens) > qidx + 2 * len(ttokens)) else len(qtokens) | |||||
| while eidx > qidx: | |||||
| ptokens = qtokens[qidx:eidx] | |||||
| weight = self.get_match_score(ptokens, ttokens) | |||||
| if weight + 0.001 > score: | |||||
| score = weight | |||||
| phrase = ''.join(ptokens) | |||||
| eidx -= 1 | |||||
| if self.is_number(target) and phrase != target: | |||||
| score = 0.0 | |||||
| if len(phrase) > 1 and phrase in target: | |||||
| score *= (1.0 + 0.05 * len(phrase)) | |||||
| return phrase, score | |||||
| def allfindpairidx(self, que_tok, value_tok, weight): | |||||
| idxs = [] | |||||
| for i in range(0, len(que_tok) - len(value_tok) + 1, 1): | |||||
| s = i | |||||
| e = i | |||||
| matched = True | |||||
| for j in range(0, len(value_tok), 1): | |||||
| if value_tok[j].lower() == que_tok[i + j].lower(): | |||||
| e = i + j | |||||
| else: | |||||
| matched = False | |||||
| break | |||||
| if matched: | |||||
| idxs.append([s, e, weight]) | |||||
| return idxs | |||||
| def findnear(self, ps1, pe1, ps2, pe2): | |||||
| if abs(ps1 - pe2) <= 2 or abs(pe1 - ps2) <= 2: | |||||
| return True | |||||
| return False | |||||
| def get_column_type(self, col_idx, table): | |||||
| colType = table['header_types'][col_idx] | |||||
| if 'number' in colType or 'duration' in colType or 'real' in colType: | |||||
| colType = 'real' | |||||
| elif 'date' in colType: | |||||
| colType = 'date' | |||||
| elif 'bool' in colType: | |||||
| colType = 'bool' | |||||
| else: | |||||
| colType = 'text' | |||||
| return colType | |||||
| def add_type_all(self, typeinfos, index, idxs, label, linktype, value, | |||||
| orgvalue): | |||||
| for idx in idxs: | |||||
| info = TypeInfo(label, index, linktype, value, orgvalue, idx[0], | |||||
| idx[1], idx[2]) | |||||
| flag = True | |||||
| for i, typeinfo in enumerate(typeinfos): | |||||
| if info.pstart < typeinfo.pstart: | |||||
| typeinfos.insert(i, info) | |||||
| flag = False | |||||
| break | |||||
| if flag: | |||||
| typeinfos.append(info) | |||||
| return typeinfos | |||||
| def save_info(self, tinfo, sinfo): | |||||
| flag = True | |||||
| if tinfo.pstart > sinfo.pend or tinfo.pend < sinfo.pstart: | |||||
| pass | |||||
| elif tinfo.pstart >= sinfo.pstart and \ | |||||
| tinfo.pend <= sinfo.pend and tinfo.index == -1: | |||||
| flag = False | |||||
| elif tinfo.pstart == sinfo.pstart and sinfo.pend == tinfo.pend and \ | |||||
| abs(tinfo.weight - sinfo.weight) < 0.01: | |||||
| pass | |||||
| else: | |||||
| if sinfo.label == 'col' or sinfo.label == 'val': | |||||
| if tinfo.label == 'col' or tinfo.label == 'val': | |||||
| if (sinfo.pend | |||||
| - sinfo.pstart) > (tinfo.pend - tinfo.pstart) or ( | |||||
| sinfo.weight > tinfo.weight | |||||
| and sinfo.index != -1): | |||||
| flag = False | |||||
| else: | |||||
| flag = False | |||||
| else: | |||||
| if (tinfo.label == 'op' or tinfo.label == 'agg'): | |||||
| if (sinfo.pend - sinfo.pstart) > ( | |||||
| tinfo.pend | |||||
| - tinfo.pstart) or sinfo.weight > tinfo.weight: | |||||
| flag = False | |||||
| return flag | |||||
| def normal_type_infos(self, infos): | |||||
| typeinfos = [] | |||||
| for info in infos: | |||||
| typeinfos = [x for x in typeinfos if self.save_info(x, info)] | |||||
| flag = True | |||||
| for i, typeinfo in enumerate(typeinfos): | |||||
| if not self.save_info(info, typeinfo): | |||||
| flag = False | |||||
| break | |||||
| if info.pstart < typeinfo.pstart: | |||||
| typeinfos.insert(i, info) | |||||
| flag = False | |||||
| break | |||||
| if flag: | |||||
| typeinfos.append(info) | |||||
| return typeinfos | |||||
| def findnear_typeinfo(self, info1, info2): | |||||
| return self.findnear(info1.pstart, info1.pend, info2.pstart, | |||||
| info2.pend) | |||||
| def find_real_column(self, infos, table): | |||||
| for i, vinfo in enumerate(infos): | |||||
| if vinfo.index != -1 or vinfo.label != 'val': | |||||
| continue | |||||
| eoidx = -1 | |||||
| for j, oinfo in enumerate(infos): | |||||
| if oinfo.label != 'op': | |||||
| continue | |||||
| if self.findnear_typeinfo(vinfo, oinfo): | |||||
| eoidx = j | |||||
| break | |||||
| for j, cinfo in enumerate(infos): | |||||
| if cinfo.label != 'col' or table['header_types'][ | |||||
| cinfo.index] != 'real': | |||||
| continue | |||||
| if self.findnear_typeinfo(cinfo, vinfo) or ( | |||||
| eoidx != -1 | |||||
| and self.findnear_typeinfo(cinfo, infos[eoidx])): | |||||
| infos[i].index = cinfo.index | |||||
| break | |||||
| return infos | |||||
| def filter_column_infos(self, infos): | |||||
| delid = [] | |||||
| for i, info in enumerate(infos): | |||||
| if info.label != 'col': | |||||
| continue | |||||
| for j in range(i + 1, len(infos), 1): | |||||
| if infos[j].label == 'col' and \ | |||||
| info.pstart == infos[j].pstart and \ | |||||
| info.pend == infos[j].pend: | |||||
| delid.append(i) | |||||
| delid.append(j) | |||||
| break | |||||
| typeinfos = [] | |||||
| for idx, info in enumerate(infos): | |||||
| if idx in set(delid): | |||||
| continue | |||||
| typeinfos.append(info) | |||||
| return typeinfos | |||||
| def filter_type_infos(self, infos, table): | |||||
| infos = self.filter_column_infos(infos) | |||||
| infos = self.find_real_column(infos, table) | |||||
| colvalMp = {} | |||||
| for info in infos: | |||||
| if info.label == 'col': | |||||
| colvalMp[info.index] = [] | |||||
| for info in infos: | |||||
| if info.label == 'val' and info.index in colvalMp: | |||||
| colvalMp[info.index].append(info) | |||||
| delid = [] | |||||
| for idx, info in enumerate(infos): | |||||
| if info.label != 'val' or info.index in colvalMp: | |||||
| continue | |||||
| for index in colvalMp.keys(): | |||||
| valinfos = colvalMp[index] | |||||
| for valinfo in valinfos: | |||||
| if valinfo.pstart <= info.pstart and \ | |||||
| valinfo.pend >= info.pend: | |||||
| delid.append(idx) | |||||
| break | |||||
| typeinfos = [] | |||||
| for idx, info in enumerate(infos): | |||||
| if idx in set(delid): | |||||
| continue | |||||
| typeinfos.append(info) | |||||
| return typeinfos | |||||
| def get_table_match_score(self, nlu_t, schema_link): | |||||
| match_len = 0 | |||||
| for info in schema_link: | |||||
| scale = 0.6 | |||||
| if info['question_len'] > 0 and info['column_index'] != -1: | |||||
| scale = 1.0 | |||||
| else: | |||||
| scale = 0.5 | |||||
| match_len += scale * info['question_len'] * info['weight'] | |||||
| return match_len / (len(nlu_t) + 0.1) | |||||
| def get_entity_linking(self, tokenizer, nlu, nlu_t, tables, col_syn_dict): | |||||
| """ | |||||
| get linking between question and schema column | |||||
| """ | |||||
| typeinfos = [] | |||||
| numbers = re.findall(r'[-]?\d*\.\d+|[-]?\d+|\d+', nlu) | |||||
| # search schema link in every table | |||||
| search_result_list = [] | |||||
| for tablename in tables: | |||||
| table = tables[tablename] | |||||
| trie_set = None | |||||
| if 'value_trie' in table: | |||||
| trie_set = table['value_trie'] | |||||
| typeinfos = [] | |||||
| for ii, column in enumerate(table['header_name']): | |||||
| column = column.lower() | |||||
| column_new = re.sub('(.*?)', '', column) | |||||
| column_new = re.sub('(.*?)', '', column_new) | |||||
| cphrase, cscore = self.get_match_phrase( | |||||
| nlu.lower(), column_new) | |||||
| if cscore > 0.3 and cphrase.strip() != '': | |||||
| phrase_tok = tokenizer.tokenize(cphrase) | |||||
| cidxs = self.allfindpairidx(nlu_t, phrase_tok, cscore) | |||||
| typeinfos = self.add_type_all(typeinfos, ii, cidxs, 'col', | |||||
| 'column', cphrase, column) | |||||
| if cscore < 0.8 and column_new in col_syn_dict: | |||||
| columns = list(set(col_syn_dict[column_new])) | |||||
| for syn_col in columns: | |||||
| if syn_col not in nlu.lower() or syn_col == '': | |||||
| continue | |||||
| phrase_tok = tokenizer.tokenize(syn_col) | |||||
| cidxs = self.allfindpairidx(nlu_t, phrase_tok, 1.0) | |||||
| typeinfos = self.add_type_all(typeinfos, ii, cidxs, | |||||
| 'col', 'column', syn_col, | |||||
| column) | |||||
| for ii, trie in enumerate(trie_set): | |||||
| ans = trie.match(nlu.lower()) | |||||
| for cell in ans.keys(): | |||||
| vphrase = cell | |||||
| vscore = 1.0 | |||||
| # print("trie_set find:", cell, ans[cell]) | |||||
| phrase_tok = tokenizer.tokenize(vphrase) | |||||
| if len(phrase_tok) == 0 or len(vphrase) < 2: | |||||
| continue | |||||
| vidxs = self.allfindpairidx(nlu_t, phrase_tok, vscore) | |||||
| linktype = self.get_column_type(ii, table) | |||||
| typeinfos = self.add_type_all(typeinfos, ii, vidxs, 'val', | |||||
| linktype, vphrase, ans[cell]) | |||||
| for number in set(numbers): | |||||
| number_tok = tokenizer.tokenize(number.lower()) | |||||
| if len(number_tok) == 0: | |||||
| continue | |||||
| nidxs = self.allfindpairidx(nlu_t, number_tok, 1.0) | |||||
| typeinfos = self.add_type_all(typeinfos, -1, nidxs, 'val', | |||||
| 'real', number, number) | |||||
| newtypeinfos = self.normal_type_infos(typeinfos) | |||||
| newtypeinfos = self.filter_type_infos(newtypeinfos, table) | |||||
| final_question = [0] * len(nlu_t) | |||||
| final_header = [0] * len(table['header_name']) | |||||
| for typeinfo in newtypeinfos: | |||||
| pstart = typeinfo.pstart | |||||
| pend = typeinfo.pend + 1 | |||||
| if typeinfo.label == 'op' or typeinfo.label == 'agg': | |||||
| score = int(typeinfo.linktype[-1]) | |||||
| if typeinfo.label == 'op': | |||||
| score += 6 | |||||
| else: | |||||
| score += 11 | |||||
| for i in range(pstart, pend, 1): | |||||
| final_question[i] = score | |||||
| elif typeinfo.label == 'col': | |||||
| for i in range(pstart, pend, 1): | |||||
| final_question[i] = 4 | |||||
| if final_header[typeinfo.index] % 2 == 0: | |||||
| final_header[typeinfo.index] += 1 | |||||
| elif typeinfo.label == 'val': | |||||
| if typeinfo.index == -1: | |||||
| for i in range(pstart, pend, 1): | |||||
| final_question[i] = 5 | |||||
| else: | |||||
| for i in range(pstart, pend, 1): | |||||
| final_question[i] = 2 | |||||
| final_question[pstart] = 1 | |||||
| final_question[pend - 1] = 3 | |||||
| if final_header[typeinfo.index] < 2: | |||||
| final_header[typeinfo.index] += 2 | |||||
| # collect schema_link | |||||
| schema_link = [] | |||||
| for sl in newtypeinfos: | |||||
| if sl.label in ['val', 'col']: | |||||
| schema_link.append({ | |||||
| 'question_len': | |||||
| max(0, sl.pend - sl.pstart + 1), | |||||
| 'question_index': [sl.pstart, sl.pend], | |||||
| 'question_span': | |||||
| ''.join(nlu_t[sl.pstart:sl.pend + 1]), | |||||
| 'column_index': | |||||
| sl.index, | |||||
| 'column_span': | |||||
| table['header_name'][sl.index] | |||||
| if sl.index != -1 else '空列', | |||||
| 'label': | |||||
| sl.label, | |||||
| 'weight': | |||||
| round(sl.weight, 4) | |||||
| }) | |||||
| # get the match score of each table | |||||
| match_score = self.get_table_match_score(nlu_t, schema_link) | |||||
| search_result = { | |||||
| 'table_id': table['table_id'], | |||||
| 'question_knowledge': final_question, | |||||
| 'header_knowledge': final_header, | |||||
| 'schema_link': schema_link, | |||||
| 'match_score': match_score | |||||
| } | |||||
| search_result_list.append(search_result) | |||||
| search_result_list = sorted( | |||||
| search_result_list, key=lambda x: x['match_score'], | |||||
| reverse=True)[0:4] | |||||
| return search_result_list | |||||
| @@ -0,0 +1,181 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| cond_ops = ['>', '<', '==', '!=', 'ASC', 'DESC'] | |||||
| agg_ops = [ | |||||
| '', 'AVG', 'MAX', 'MIN', 'COUNT', 'SUM', 'COMPARE', 'GROUP BY', 'SAME' | |||||
| ] | |||||
| conn_ops = ['', 'AND', 'OR'] | |||||
| class Context: | |||||
| def __init__(self): | |||||
| self.history_sql = None | |||||
| def set_history_sql(self, sql): | |||||
| self.history_sql = sql | |||||
| class SQLQuery: | |||||
| def __init__(self, string, query, sql_result): | |||||
| self.string = string | |||||
| self.query = query | |||||
| self.sql_result = sql_result | |||||
| class TrieNode(object): | |||||
| def __init__(self): | |||||
| """ | |||||
| Initialize your data structure here. | |||||
| """ | |||||
| self.data = {} | |||||
| self.is_word = False | |||||
| self.term = None | |||||
| class Trie(object): | |||||
| def __init__(self): | |||||
| self.root = TrieNode() | |||||
| def insert(self, word, term): | |||||
| """ | |||||
| Inserts a word into the trie. | |||||
| :type word: str | |||||
| :rtype: void | |||||
| """ | |||||
| node = self.root | |||||
| for letter in word: | |||||
| child = node.data.get(letter) | |||||
| if not child: | |||||
| node.data[letter] = TrieNode() | |||||
| node = node.data[letter] | |||||
| node.is_word = True | |||||
| node.term = term | |||||
| def search(self, word): | |||||
| """ | |||||
| Returns if the word is in the trie. | |||||
| :type word: str | |||||
| :rtype: bool | |||||
| """ | |||||
| node = self.root | |||||
| for letter in word: | |||||
| node = node.data.get(letter) | |||||
| if not node: | |||||
| return None, False | |||||
| return node.term, True | |||||
| def match(self, query): | |||||
| start = 0 | |||||
| end = 1 | |||||
| length = len(query) | |||||
| ans = {} | |||||
| while start < length and end < length: | |||||
| sub = query[start:end] | |||||
| term, flag = self.search(sub) | |||||
| if flag: | |||||
| if term is not None: | |||||
| ans[sub] = term | |||||
| end += 1 | |||||
| else: | |||||
| start += 1 | |||||
| end = start + 1 | |||||
| return ans | |||||
| def starts_with(self, prefix): | |||||
| """ | |||||
| Returns if there is any word in the trie | |||||
| that starts with the given prefix. | |||||
| :type prefix: str | |||||
| :rtype: bool | |||||
| """ | |||||
| node = self.root | |||||
| for letter in prefix: | |||||
| node = node.data.get(letter) | |||||
| if not node: | |||||
| return False | |||||
| return True | |||||
| def get_start(self, prefix): | |||||
| """ | |||||
| Returns words started with prefix | |||||
| :param prefix: | |||||
| :return: words (list) | |||||
| """ | |||||
| def _get_key(pre, pre_node): | |||||
| words_list = [] | |||||
| if pre_node.is_word: | |||||
| words_list.append(pre) | |||||
| for x in pre_node.data.keys(): | |||||
| words_list.extend(_get_key(pre + str(x), pre_node.data.get(x))) | |||||
| return words_list | |||||
| words = [] | |||||
| if not self.starts_with(prefix): | |||||
| return words | |||||
| if self.search(prefix): | |||||
| words.append(prefix) | |||||
| return words | |||||
| node = self.root | |||||
| for letter in prefix: | |||||
| node = node.data.get(letter) | |||||
| return _get_key(prefix, node) | |||||
| class TypeInfo: | |||||
| def __init__(self, label, index, linktype, value, orgvalue, pstart, pend, | |||||
| weight): | |||||
| self.label = label | |||||
| self.index = index | |||||
| self.linktype = linktype | |||||
| self.value = value | |||||
| self.orgvalue = orgvalue | |||||
| self.pstart = pstart | |||||
| self.pend = pend | |||||
| self.weight = weight | |||||
| class Constant: | |||||
| def __init__(self): | |||||
| self.action_ops = [ | |||||
| 'add_cond', 'change_cond', 'del_cond', 'change_focus_total', | |||||
| 'change_agg_only', 'del_focus', 'restart', 'switch_table', | |||||
| 'out_of_scripts', 'repeat', 'firstTurn' | |||||
| ] | |||||
| self.agg_ops = [ | |||||
| '', 'AVG', 'MAX', 'MIN', 'COUNT', 'SUM', 'COMPARE', 'GROUP BY', | |||||
| 'SAME' | |||||
| ] | |||||
| self.cond_ops = ['>', '<', '==', '!=', 'ASC', 'DESC'] | |||||
| self.cond_conn_ops = ['', 'AND', 'OR'] | |||||
| self.col_type_dict = { | |||||
| 'null': 0, | |||||
| 'text': 1, | |||||
| 'number': 2, | |||||
| 'duration': 3, | |||||
| 'bool': 4, | |||||
| 'date': 5 | |||||
| } | |||||
| self.schema_link_dict = { | |||||
| 'col_start': 1, | |||||
| 'col_middle': 2, | |||||
| 'col_end': 3, | |||||
| 'val_start': 4, | |||||
| 'val_middle': 5, | |||||
| 'val_end': 6 | |||||
| } | |||||
| self.max_select_num = 4 | |||||
| self.max_where_num = 6 | |||||
| @@ -0,0 +1,118 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| from typing import Any, Dict | |||||
| import torch | |||||
| from transformers import BertTokenizer | |||||
| from modelscope.metainfo import Preprocessors | |||||
| from modelscope.preprocessors.base import Preprocessor | |||||
| from modelscope.preprocessors.builder import PREPROCESSORS | |||||
| from modelscope.preprocessors.star3.fields.database import Database | |||||
| from modelscope.preprocessors.star3.fields.schema_link import SchemaLinker | |||||
| from modelscope.utils.config import Config | |||||
| from modelscope.utils.constant import Fields, ModelFile | |||||
| from modelscope.utils.type_assert import type_assert | |||||
| __all__ = ['TableQuestionAnsweringPreprocessor'] | |||||
| @PREPROCESSORS.register_module( | |||||
| Fields.nlp, | |||||
| module_name=Preprocessors.table_question_answering_preprocessor) | |||||
| class TableQuestionAnsweringPreprocessor(Preprocessor): | |||||
| def __init__(self, model_dir: str, db: Database = None, *args, **kwargs): | |||||
| """preprocess the data | |||||
| Args: | |||||
| model_dir (str): model path | |||||
| db (Database): database instance | |||||
| """ | |||||
| super().__init__(*args, **kwargs) | |||||
| self.model_dir: str = model_dir | |||||
| self.config = Config.from_file( | |||||
| os.path.join(self.model_dir, ModelFile.CONFIGURATION)) | |||||
| # read tokenizer | |||||
| self.tokenizer = BertTokenizer( | |||||
| os.path.join(self.model_dir, ModelFile.VOCAB_FILE)) | |||||
| # read database | |||||
| if db is None: | |||||
| self.db = Database( | |||||
| tokenizer=self.tokenizer, | |||||
| table_file_path=os.path.join(self.model_dir, 'table.json'), | |||||
| syn_dict_file_path=os.path.join(self.model_dir, 'synonym.txt')) | |||||
| else: | |||||
| self.db = db | |||||
| # get schema linker | |||||
| self.schema_linker = SchemaLinker() | |||||
| # set device | |||||
| self.device = 'cuda' if \ | |||||
| ('device' not in kwargs or kwargs['device'] == 'gpu') \ | |||||
| and torch.cuda.is_available() else 'cpu' | |||||
| def construct_data(self, search_result_list, nlu, nlu_t, db, history_sql): | |||||
| datas = [] | |||||
| for search_result in search_result_list: | |||||
| data = {} | |||||
| data['table_id'] = search_result['table_id'] | |||||
| data['question'] = nlu | |||||
| data['question_tok'] = nlu_t | |||||
| data['header_tok'] = db.tables[data['table_id']]['header_tok'] | |||||
| data['types'] = db.tables[data['table_id']]['header_types'] | |||||
| data['units'] = db.tables[data['table_id']]['header_units'] | |||||
| data['action'] = 0 | |||||
| data['sql'] = None | |||||
| data['history_sql'] = history_sql | |||||
| data['wvi_corenlp'] = [] | |||||
| data['bertindex_knowledge'] = search_result['question_knowledge'] | |||||
| data['header_knowledge'] = search_result['header_knowledge'] | |||||
| data['schema_link'] = search_result['schema_link'] | |||||
| datas.append(data) | |||||
| return datas | |||||
| @type_assert(object, dict) | |||||
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||||
| """process the raw input data | |||||
| Args: | |||||
| data (dict): | |||||
| utterance: a sentence | |||||
| last_sql: predicted sql of last utterance | |||||
| Example: | |||||
| utterance: 'Which of these are hiring?' | |||||
| last_sql: '' | |||||
| Returns: | |||||
| Dict[str, Any]: the preprocessed data | |||||
| """ | |||||
| # tokenize question | |||||
| question = data['question'] | |||||
| history_sql = data['history_sql'] | |||||
| nlu = question.lower() | |||||
| nlu_t = self.tokenizer.tokenize(nlu) | |||||
| # get linking | |||||
| search_result_list = self.schema_linker.get_entity_linking( | |||||
| tokenizer=self.tokenizer, | |||||
| nlu=nlu, | |||||
| nlu_t=nlu_t, | |||||
| tables=self.db.tables, | |||||
| col_syn_dict=self.db.syn_dict) | |||||
| # collect data | |||||
| datas = self.construct_data( | |||||
| search_result_list=search_result_list[0:1], | |||||
| nlu=nlu, | |||||
| nlu_t=nlu_t, | |||||
| db=self.db, | |||||
| history_sql=history_sql) | |||||
| return {'datas': datas} | |||||
| @@ -3,7 +3,8 @@ from typing import List | |||||
| from modelscope.outputs import OutputKeys | from modelscope.outputs import OutputKeys | ||||
| from modelscope.pipelines.nlp import (ConversationalTextToSqlPipeline, | from modelscope.pipelines.nlp import (ConversationalTextToSqlPipeline, | ||||
| DialogStateTrackingPipeline) | |||||
| DialogStateTrackingPipeline, | |||||
| TableQuestionAnsweringPipeline) | |||||
| def text2sql_tracking_and_print_results( | def text2sql_tracking_and_print_results( | ||||
| @@ -42,3 +43,17 @@ def tracking_and_print_dialog_states( | |||||
| print(json.dumps(result)) | print(json.dumps(result)) | ||||
| history_states.extend([result[OutputKeys.OUTPUT], {}]) | history_states.extend([result[OutputKeys.OUTPUT], {}]) | ||||
| def tableqa_tracking_and_print_results( | |||||
| test_case, pipelines: List[TableQuestionAnsweringPipeline]): | |||||
| for pipeline in pipelines: | |||||
| historical_queries = None | |||||
| for question in test_case['utterance']: | |||||
| output_dict = pipeline({ | |||||
| 'question': question, | |||||
| 'history_sql': historical_queries | |||||
| }) | |||||
| print('output_dict', output_dict['output'].string, | |||||
| output_dict['output'].query) | |||||
| historical_queries = output_dict['history'] | |||||
| @@ -0,0 +1,76 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import unittest | |||||
| from typing import List | |||||
| from transformers import BertTokenizer | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.models import Model | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline | |||||
| from modelscope.preprocessors import TableQuestionAnsweringPreprocessor | |||||
| from modelscope.preprocessors.star3.fields.database import Database | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.nlp.nlp_utils import tableqa_tracking_and_print_results | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class TableQuestionAnswering(unittest.TestCase): | |||||
| def setUp(self) -> None: | |||||
| self.task = Tasks.table_question_answering | |||||
| self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn' | |||||
| model_id = 'damo/nlp_convai_text2sql_pretrain_cn' | |||||
| test_case = { | |||||
| 'utterance': | |||||
| ['长江流域的小(2)型水库的库容总量是多少?', '那平均值是多少?', '那水库的名称呢?', '换成中型的呢?'] | |||||
| } | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_by_direct_model_download(self): | |||||
| cache_path = snapshot_download(self.model_id) | |||||
| preprocessor = TableQuestionAnsweringPreprocessor(model_dir=cache_path) | |||||
| pipelines = [ | |||||
| TableQuestionAnsweringPipeline( | |||||
| model=cache_path, preprocessor=preprocessor) | |||||
| ] | |||||
| tableqa_tracking_and_print_results(self.test_case, pipelines) | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_run_with_model_from_modelhub(self): | |||||
| model = Model.from_pretrained(self.model_id) | |||||
| preprocessor = TableQuestionAnsweringPreprocessor( | |||||
| model_dir=model.model_dir) | |||||
| pipelines = [ | |||||
| TableQuestionAnsweringPipeline( | |||||
| model=model, preprocessor=preprocessor) | |||||
| ] | |||||
| tableqa_tracking_and_print_results(self.test_case, pipelines) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_model_from_task(self): | |||||
| pipelines = [pipeline(Tasks.table_question_answering, self.model_id)] | |||||
| tableqa_tracking_and_print_results(self.test_case, pipelines) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_with_model_from_modelhub_with_other_classes(self): | |||||
| model = Model.from_pretrained(self.model_id) | |||||
| self.tokenizer = BertTokenizer( | |||||
| os.path.join(model.model_dir, ModelFile.VOCAB_FILE)) | |||||
| db = Database( | |||||
| tokenizer=self.tokenizer, | |||||
| table_file_path=os.path.join(model.model_dir, 'table.json'), | |||||
| syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt')) | |||||
| preprocessor = TableQuestionAnsweringPreprocessor( | |||||
| model_dir=model.model_dir, db=db) | |||||
| pipelines = [ | |||||
| TableQuestionAnsweringPipeline( | |||||
| model=model, preprocessor=preprocessor, db=db) | |||||
| ] | |||||
| tableqa_tracking_and_print_results(self.test_case, pipelines) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||