commit nlp_convai_text2sql_pretrain_cn inference process to modelscope
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10025155
master
| @@ -55,6 +55,7 @@ class Models(object): | |||
| space_intent = 'space-intent' | |||
| space_modeling = 'space-modeling' | |||
| star = 'star' | |||
| star3 = 'star3' | |||
| tcrf = 'transformer-crf' | |||
| transformer_softmax = 'transformer-softmax' | |||
| lcrf = 'lstm-crf' | |||
| @@ -193,6 +194,7 @@ class Pipelines(object): | |||
| plug_generation = 'plug-generation' | |||
| faq_question_answering = 'faq-question-answering' | |||
| conversational_text_to_sql = 'conversational-text-to-sql' | |||
| table_question_answering_pipeline = 'table-question-answering-pipeline' | |||
| sentence_embedding = 'sentence-embedding' | |||
| passage_ranking = 'passage-ranking' | |||
| relation_extraction = 'relation-extraction' | |||
| @@ -296,6 +298,7 @@ class Preprocessors(object): | |||
| fill_mask_ponet = 'fill-mask-ponet' | |||
| faq_question_answering_preprocessor = 'faq-question-answering-preprocessor' | |||
| conversational_text_to_sql = 'conversational-text-to-sql' | |||
| table_question_answering_preprocessor = 'table-question-answering-preprocessor' | |||
| re_tokenizer = 're-tokenizer' | |||
| document_segmentation = 'document-segmentation' | |||
| @@ -24,6 +24,7 @@ if TYPE_CHECKING: | |||
| from .space import SpaceForDialogIntent | |||
| from .space import SpaceForDialogModeling | |||
| from .space import SpaceForDialogStateTracking | |||
| from .table_question_answering import TableQuestionAnswering | |||
| from .task_models import (InformationExtractionModel, | |||
| SequenceClassificationModel, | |||
| SingleBackboneTaskModelBase, | |||
| @@ -64,6 +65,7 @@ else: | |||
| 'SingleBackboneTaskModelBase', 'TokenClassificationModel' | |||
| ], | |||
| 'token_classification': ['SbertForTokenClassification'], | |||
| 'table_question_answering': ['TableQuestionAnswering'], | |||
| 'sentence_embedding': ['SentenceEmbedding'], | |||
| 'passage_ranking': ['PassageRanking'], | |||
| } | |||
| @@ -0,0 +1,128 @@ | |||
| # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
| # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||
| # Copyright 2021-2022 The Alibaba DAMO Team Authors. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """PyTorch BERT configuration.""" | |||
| from __future__ import absolute_import, division, print_function | |||
| import copy | |||
| import logging | |||
| import math | |||
| import os | |||
| import shutil | |||
| import tarfile | |||
| import tempfile | |||
| from pathlib import Path | |||
| from typing import Union | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| import torch_scatter | |||
| from icecream import ic | |||
| from torch import nn | |||
| from torch.nn import CrossEntropyLoss | |||
| logger = logging.getLogger(__name__) | |||
class Star3Config(object):
    """Configuration class to store the configuration of a `Star3Model`.

    Every setting is stored as a plain instance attribute, so the whole
    configuration can be round-tripped through a dict or a JSON string.
    """

    def __init__(self,
                 vocab_size_or_config_json_file,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act='gelu',
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02):
        """Constructs Star3Config.

        Args:
            vocab_size_or_config_json_file: Either the vocabulary size of
                `inputs_ids` in `Star3Model` (int), or the path to a JSON
                config file (str). When a path is given, every key of the
                JSON object becomes an attribute and the remaining keyword
                arguments are ignored.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer
                encoder.
            num_attention_heads: Number of attention heads for each attention
                layer in the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e.,
                feed-forward) layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or
                string) in the encoder and pooler. If string, "gelu", "relu"
                and "swish" are supported.
            hidden_dropout_prob: The dropout probability for all fully
                connected layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this
                model might ever be used with. Typically set this to
                something large just in case (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids`
                passed into `Star3Model`.
            initializer_range: The stddev of the truncated_normal_initializer
                for initializing all weight matrices.

        Raises:
            ValueError: If the first argument is neither an int nor a str.
        """
        if isinstance(vocab_size_or_config_json_file, str):
            with open(
                    vocab_size_or_config_json_file, 'r',
                    encoding='utf-8') as reader:
                json_config = json.load(reader)
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
        else:
            # The two literals are concatenated by the parser; the trailing
            # space keeps the message readable ("(int) or the path ...").
            raise ValueError(
                'First argument must be either a vocabulary size (int) '
                'or the path to a pretrained model config file (str)')

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a config from a Python dictionary of parameters.

        The instance is created through `cls`, so calling this on a subclass
        returns an instance of that subclass. It starts from the defaults
        (with `vocab_size` set to -1) and every key of `json_object` then
        overrides or adds an attribute.
        """
        config = cls(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a config from a json file of parameters."""
        with open(json_file, 'r', encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary (deep copy)."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n'
| @@ -0,0 +1,747 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| from typing import Dict, Optional | |||
| import numpy | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from transformers import BertTokenizer | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import Model, Tensor | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.models.nlp.star3.configuration_star3 import Star3Config | |||
| from modelscope.models.nlp.star3.modeling_star3 import Seq2SQL, Star3Model | |||
| from modelscope.preprocessors.star3.fields.struct import Constant | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.device import verify_device | |||
| __all__ = ['TableQuestionAnswering'] | |||
| @MODELS.register_module( | |||
| Tasks.table_question_answering, module_name=Models.star3) | |||
| class TableQuestionAnswering(Model): | |||
def __init__(self, model_dir: str, *args, **kwargs):
    """Load the tokenizer, the backbone and the SQL head from `model_dir`.

    Args:
        model_dir (str): the model path; the vocabulary file, the backbone
            configuration and the checkpoint are all read from it.
        **kwargs: `device` (default 'gpu') names the device both
            sub-models are moved to. All arguments are also forwarded to
            the base class constructor.
    """
    super().__init__(model_dir, *args, **kwargs)

    vocab_path = os.path.join(model_dir, ModelFile.VOCAB_FILE)
    self.tokenizer = BertTokenizer(vocab_path)

    # NOTE(review): the raw name is later handed to `.to()` and to Seq2SQL.
    # torch device strings are normally 'cuda'/'cpu' -- confirm that the
    # default 'gpu' is accepted (or translated) downstream.
    requested_device = kwargs.get('device', 'gpu')
    verify_device(requested_device)
    self._device_name = requested_device

    # Weights are first materialised on the CPU and moved at the very end.
    checkpoint_path = os.path.join(self.model_dir,
                                   ModelFile.TORCH_MODEL_BIN_FILE)
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    config_path = os.path.join(self.model_dir, ModelFile.CONFIGURATION)
    self.backbone_config = Star3Config.from_json_file(config_path)
    self.backbone_model = Star3Model(
        config=self.backbone_config, schema_link_module='rat')
    self.backbone_model.load_state_dict(checkpoint['backbone_model'])

    # Mirror the task constants on the instance.
    consts = Constant()
    for field in ('agg_ops', 'cond_ops', 'cond_conn_ops', 'action_ops',
                  'max_select_num', 'max_where_num', 'col_type_dict',
                  'schema_link_dict'):
        setattr(self, field, getattr(consts, field))

    self.head_model = Seq2SQL(
        self.backbone_config.hidden_size, 100, 2, 0.0,
        len(self.cond_ops), len(self.agg_ops), len(self.action_ops),
        self.max_select_num, self.max_where_num, self._device_name)
    # strict=False: keys that do not match between checkpoint and head are
    # ignored rather than raising.
    self.head_model.load_state_dict(checkpoint['head_model'], strict=False)

    self.backbone_model.to(self._device_name)
    self.head_model.to(self._device_name)
def convert_string(self, pr_wvi, nlu, nlu_tt):
    """Turn predicted token-index spans into strings taken from the questions.

    For every question the tokens in `nlu_tt[b]` are aligned, one by one,
    against the characters of the raw string `nlu[b]`. This yields a map
    from token index to a `[start, end]` character range. If the alignment
    succeeds, each predicted span is resolved to a slice of the raw
    question; otherwise the tokens themselves are joined (with '##'
    removed) as a fallback.

    Args:
        pr_wvi: per question, a list of `[start_token, end_token]` index
            pairs into `nlu_tt[b]`.
        nlu: list of raw question strings.
        nlu_tt: per question, the list of token strings.

    Returns:
        A list with one entry per question, each a list of strings (one per
        predicted span).
    """
    convs = []
    for b, nlu1 in enumerate(nlu):
        # token index -> [char_start, char_end] inside nlu1
        conv_dict = {}
        nlu_tt1 = nlu_tt[b]
        # idx is the cursor into the raw question string.
        idx = 0
        convflag = True
        for i, ntok in enumerate(nlu_tt1):
            if idx >= len(nlu1):
                # Tokens remain but the raw string is exhausted: give up on
                # the character alignment for this question.
                convflag = False
                break
            if ntok.startswith('##'):
                ntok = ntok.replace('##', '')
            # Compare against a single lower-cased character at the cursor.
            tok = nlu1[idx:idx + 1].lower()
            if ntok == tok:
                conv_dict[i] = [idx, idx + 1]
                idx += 1
            elif ntok == '#':
                # Mapped to an empty range; the cursor does not advance.
                conv_dict[i] = [idx, idx]
            elif ntok == '[UNK]':
                conv_dict[i] = [idx, idx + 1]
                j = i + 1
                idx += 1
                if idx < len(nlu1) and j < len(
                        nlu_tt1) and nlu_tt1[j] != '[UNK]':
                    # Skip characters until one that the next token starts
                    # with; everything skipped is attributed to this [UNK].
                    while idx < len(nlu1):
                        val = nlu1[idx:idx + 1].lower()
                        if nlu_tt1[j].startswith(val):
                            break
                        idx += 1
                conv_dict[i][1] = idx
            elif tok in ntok:
                # Multi-character token: grow `tok` one character at a time
                # until it equals the token.
                # NOTE(review): if the loop ends without a match,
                # conv_dict[i] is never set while convflag stays True; a
                # later lookup of that index would raise KeyError -- confirm.
                startid = idx
                idx += 1
                while idx < len(nlu1):
                    tok += nlu1[idx:idx + 1].lower()
                    if ntok == tok:
                        conv_dict[i] = [startid, idx + 1]
                        break
                    idx += 1
                idx += 1
            else:
                convflag = False
        conv = []
        if convflag:
            for pr_wvi1 in pr_wvi[b]:
                # e1 is unpacked but not used: the phrase runs from the
                # start of the first token to the end of the last one.
                s1, e1 = conv_dict[pr_wvi1[0]]
                s2, e2 = conv_dict[pr_wvi1[1]]
                newidx = pr_wvi1[1]
                # If the end token is an empty-range '#', move forward to
                # the following tokens.
                while newidx + 1 < len(
                        nlu_tt1) and s2 == e2 and nlu_tt1[newidx] == '#':
                    newidx += 1
                    s2, e2 = conv_dict[newidx]
                # Extend over one directly following '##' continuation piece.
                if newidx + 1 < len(nlu_tt1) and nlu_tt1[
                        newidx + 1].startswith('##'):
                    s2, e2 = conv_dict[newidx + 1]
                phrase = nlu1[s1:e2]
                conv.append(phrase)
        else:
            # Fallback: concatenate the tokens of the (inclusive) span.
            for pr_wvi1 in pr_wvi[b]:
                phrase = ''.join(nlu_tt1[pr_wvi1[0]:pr_wvi1[1]
                                         + 1]).replace('##', '')
                conv.append(phrase)
        convs.append(conv)
    return convs
def get_fields_info(self, t1s, tables, train=True):
    """Transpose a list of sample dicts into one list per field.

    Args:
        t1s: iterable of sample dicts. The keys 'question', 'question_tok',
            'header_tok', 'bertindex_knowledge', 'header_knowledge',
            'types' and 'units' are required; 'history_sql' defaults to
            None and 'schema_link' to an empty list.
        tables: not used by this method.
        train: when true, 'action' (default `[0]`) and the required key
            'sql' are collected as well; otherwise those two lists stay
            empty.

    Returns:
        The 11-tuple (questions, question tokens, sqls, question knowledge,
        header knowledge, actions, header tokens, types, units, history
        sqls, schema links).
    """
    questions, question_toks, sqls = [], [], []
    question_know, header_know, actions = [], [], []
    header_toks, col_types, col_units = [], [], []
    history, links = [], []

    for sample in t1s:
        questions.append(sample['question'])
        question_toks.append(sample['question_tok'])
        header_toks.append(sample['header_tok'])
        question_know.append(sample['bertindex_knowledge'])
        header_know.append(sample['header_knowledge'])
        col_types.append(sample['types'])
        col_units.append(sample['units'])
        history.append(sample.get('history_sql', None))
        links.append(sample.get('schema_link', []))
        if not train:
            continue
        actions.append(sample.get('action', [0]))
        sqls.append(sample['sql'])

    return (questions, question_toks, sqls, question_know, header_know,
            actions, header_toks, col_types, col_units, history, links)
def get_history_select_where(self, his_sql, header_len):
    """Extract the shifted column indices used by a previous SQL.

    Args:
        his_sql: None, or a dict with 'sel' (list of column indices) and
            'conds' (list of sequences whose first item is a column index).
        header_len: exclusive upper bound for the shifted indices.

    Returns:
        `(sel, whe)`: two sorted lists of unique `index + 1` values below
        `header_len`; a list that would be empty is replaced by `[0]`.
    """
    if his_sql is None:
        return [0], [0]

    def shifted_unique(columns):
        # Shift by one, drop out-of-range values and duplicates.
        kept = []
        for column in columns:
            position = column + 1
            if position < header_len and position not in kept:
                kept.append(position)
        return sorted(kept) if kept else [0]

    selected = shifted_unique(his_sql['sel'])
    filtered = shifted_unique(cond[0] for cond in his_sql['conds'])
    return selected, filtered
def get_types_ids(self, col_type):
    """Look up the type id for a column type name.

    The first key of `self.col_type_dict` (in dict order) that occurs as a
    substring of the lower-cased `col_type` wins; when none matches, the
    entry stored under 'null' is returned.
    """
    not_found = object()
    match = next(
        (ids for name, ids in self.col_type_dict.items()
         if name in col_type.lower()),
        not_found)
    if match is not_found:
        return self.col_type_dict['null']
    return match
def generate_inputs(self, nlu1_tok, hs_t_1, type_t, unit_t, his_sql,
                    q_know, t_know, s_link):
    """Build the token sequence and aligned per-token features of one sample.

    The sequence is laid out as: `[CLS]`, the question tokens, `[SEP]`, an
    optional history section (`select`, one `[PAD]` per history select
    column, `where`, one `[PAD]` per history where column, `[SEP]`), one
    `[PAD]` slot per header, `[SEP]`, and finally the marker tokens
    `action`, `connect`, `allen`, `max_where_num` times `act`, `size` and
    `max_select_num` times `focus`. `tokens`, `orders`, `types`,
    `segment_ids` and `matchs` are kept the same length throughout.

    Args:
        nlu1_tok: question tokens.
        hs_t_1: per header, its list of tokens.
        type_t: per header, the column type name (see `get_types_ids`).
        unit_t: per header, its list of unit tokens; `['null']` means none.
        his_sql: previous SQL dict or None (see
            `get_history_select_where`).
        q_know: per question token, an integer knowledge tag.
        t_know: per header, an integer knowledge tag.
        s_link: list of relation dicts with 'label', 'question_index' and
            'column_index'.

    Returns:
        The 15-tuple (tokens, orders, types, segment_ids, matchs, i_nlu,
        i_hds_f, start_ids, column_start, col_dict, schema_tok,
        header_flatten_tokens, header_flatten_index, schema_link_matrix,
        schema_link_mask).
    """
    tokens = []
    orders = []
    types = []
    segment_ids = []
    matchs = []
    # position in `tokens` -> column index it stands for
    col_dict = {}
    schema_tok = []
    tokens.append('[CLS]')
    orders.append(0)
    types.append(0)
    # Index of the first question token.
    i_st_nlu = len(tokens)
    matchs.append(0)
    segment_ids.append(0)
    for idx, token in enumerate(nlu1_tok):
        if q_know[idx] == 100:
            # Tag 100 ends the question; presumably a padding/termination
            # marker -- TODO confirm against the preprocessor.
            break
        elif q_know[idx] >= 5:
            matchs.append(1)
        else:
            matchs.append(q_know[idx] + 1)
        tokens.append(token)
        orders.append(0)
        types.append(0)
        segment_ids.append(0)
    # One past the last question token.
    i_ed_nlu = len(tokens)
    tokens.append('[SEP]')
    orders.append(0)
    types.append(0)
    matchs.append(0)
    segment_ids.append(0)
    sel, whe = self.get_history_select_where(his_sql, len(hs_t_1))
    # ([0], [0]) is what get_history_select_where returns when there is no
    # usable history; in that case the history section is omitted.
    if len(sel) == 1 and sel[0] == 0 \
            and len(whe) == 1 and whe[0] == 0:
        pass
    else:
        tokens.append('select')
        orders.append(0)
        types.append(0)
        matchs.append(10)
        segment_ids.append(0)
        for seli in sel:
            tokens.append('[PAD]')
            orders.append(0)
            types.append(0)
            matchs.append(10)
            segment_ids.append(0)
            col_dict[len(tokens) - 1] = seli
        tokens.append('where')
        orders.append(0)
        types.append(0)
        matchs.append(10)
        segment_ids.append(0)
        for whei in whe:
            tokens.append('[PAD]')
            orders.append(0)
            types.append(0)
            matchs.append(10)
            segment_ids.append(0)
            col_dict[len(tokens) - 1] = whei
        tokens.append('[SEP]')
        orders.append(0)
        types.append(0)
        matchs.append(10)
        segment_ids.append(0)
    # Index of the first header slot.
    column_start = len(tokens)
    i_hds_f = []
    header_flatten_tokens, header_flatten_index = [], []
    for i, hds11 in enumerate(hs_t_1):
        # The header text is the header tokens plus its unit tokens, unless
        # the unit is the single token 'null'.
        if len(unit_t[i]) == 1 and unit_t[i][0] == 'null':
            temp_header_tokens = hds11
        else:
            temp_header_tokens = hds11 + unit_t[i]
        schema_tok.append(temp_header_tokens)
        header_flatten_tokens.extend(temp_header_tokens)
        # Every flattened header token is tagged with its 1-based header id.
        header_flatten_index.extend([i + 1] * len(temp_header_tokens))
        i_st_hd_f = len(tokens)
        # Each header occupies exactly one '[PAD]' slot in the sequence.
        tokens += ['[PAD]']
        orders.append(0)
        types.append(self.get_types_ids(type_t[i]))
        i_ed_hd_f = len(tokens)
        col_dict[len(tokens) - 1] = i
        i_hds_f.append((i_st_hd_f, i_ed_hd_f))
        # Header 0 always gets 6; header i > 0 uses t_know[i - 1].
        # Presumably header 0 is a special column placed in front by the
        # caller -- verify against get_bert_output.
        if i == 0:
            matchs.append(6)
        else:
            matchs.append(t_know[i - 1] + 6)
        segment_ids.append(1)
    tokens.append('[SEP]')
    orders.append(0)
    types.append(0)
    matchs.append(0)
    segment_ids.append(1)
    # Position of the [SEP] that closes the header section; the marker
    # tokens follow directly after it.
    start_ids = len(tokens) - 1
    tokens.append('action')  # action
    orders.append(1)
    types.append(0)
    matchs.append(0)
    segment_ids.append(1)
    tokens.append('connect')  # column
    orders.append(1)
    types.append(0)
    matchs.append(0)
    segment_ids.append(1)
    tokens.append('allen')  # select len
    orders.append(1)
    types.append(0)
    matchs.append(0)
    segment_ids.append(1)
    for x in range(self.max_where_num):
        tokens.append('act')  # op
        orders.append(2 + x)
        types.append(0)
        matchs.append(0)
        segment_ids.append(1)
    tokens.append('size')  # where len
    orders.append(1)
    types.append(0)
    matchs.append(0)
    segment_ids.append(1)
    for x in range(self.max_select_num):
        tokens.append('focus')  # agg
        orders.append(2 + x)
        types.append(0)
        matchs.append(0)
        segment_ids.append(1)
    i_nlu = (i_st_nlu, i_ed_nlu)
    # Square (token x token) relation-id matrix and the mask of set cells.
    schema_link_matrix = numpy.zeros((len(tokens), len(tokens)),
                                     dtype='int32')
    schema_link_mask = numpy.zeros((len(tokens), len(tokens)),
                                   dtype='float32')
    for relation in s_link:
        if relation['label'] in ['col', 'val']:
            [q_st, q_ed] = relation['question_index']
            # Negative column indices are clamped to 0.
            cid = max(0, relation['column_index'])
            # Fill the whole question span with the '_middle' id first, then
            # overwrite its first and last row with '_start' / '_end'. For a
            # single-token span '_end' is written last and therefore wins.
            schema_link_matrix[
                i_st_nlu + q_st: i_st_nlu + q_ed + 1,
                column_start + cid + 1: column_start + cid + 1 + 1] = \
                self.schema_link_dict[relation['label'] + '_middle']
            schema_link_matrix[
                i_st_nlu + q_st,
                column_start + cid + 1: column_start + cid + 1 + 1] = \
                self.schema_link_dict[relation['label'] + '_start']
            schema_link_matrix[
                i_st_nlu + q_ed,
                column_start + cid + 1: column_start + cid + 1 + 1] = \
                self.schema_link_dict[relation['label'] + '_end']
            schema_link_mask[i_st_nlu + q_st:i_st_nlu + q_ed + 1,
                             column_start + cid + 1:column_start + cid + 1
                             + 1] = 1.0
    return tokens, orders, types, segment_ids, matchs, \
        i_nlu, i_hds_f, start_ids, column_start, col_dict, schema_tok, \
        header_flatten_tokens, header_flatten_index, schema_link_matrix, schema_link_mask
def gen_l_hpu(self, i_hds):
    """Flatten per-sample header spans into one list of span lengths.

    Args:
        i_hds: per sample, a list of `(start, end)` index pairs, e.g.
            `[[(17, 18), (19, 21)], [(22, 23)]]`.

    Returns:
        `end - start` for every pair, samples concatenated in order, e.g.
        `[1, 2, 1]` for the example above.
    """
    return [span[1] - span[0] for spans in i_hds for span in spans]
def get_bert_output(self, model_bert, tokenizer, nlu_t, hs_t, col_types,
                    units, his_sql, q_know, t_know, schema_link):
    """Encode a batch with the backbone model.

    Every sample is turned into a token sequence by `generate_inputs`, all
    per-token features are zero-padded to the longest sequence of the
    batch, converted to tensors on `self._device_name` and passed to
    `model_bert`.

    Args:
        model_bert: the backbone; called with the tensors built here.
        tokenizer: object providing `convert_tokens_to_ids`.
        nlu_t: per sample, the question tokens.
        hs_t: per sample, the per-header token lists.
        col_types: per sample, the per-header type names.
        units: per sample, the per-header unit token lists.
        his_sql: per sample, the previous SQL dict or None.
        q_know: per sample, the question knowledge tags.
        t_know: per sample, the header knowledge tags.
        schema_link: per sample, the list of schema-link relations.

    Returns:
        The tuple (all_encoder_layer, pooled_output, tokens, i_nlu, i_hds,
        l_n, l_hpu, l_hs, start_index, column_index, all_ids), where the
        first two items are what `model_bert` returned.

    Raises:
        AssertionError: if a sequence is longer than 512 tokens or the
            padded feature lists disagree in length.
    """
    l_n = []
    l_hs = []  # The length of columns for each batch
    input_ids = []
    order_ids = []
    type_ids = []
    segment_ids = []
    match_ids = []
    input_mask = []
    i_nlu = [
    ]  # index to retrieve the position of contextual vector later.
    i_hds = []
    tokens = []
    orders = []
    types = []
    matchs = []
    segments = []
    schema_link_matrix_list = []
    schema_link_mask_list = []
    start_index = []
    column_index = []
    col_dict_list = []
    header_list = []
    header_flatten_token_list = []
    header_flatten_tokenid_list = []
    header_flatten_index_list = []
    header_tok_max_len = 0
    cur_max_length = 0
    for b, nlu_t1 in enumerate(nlu_t):
        # Rotate the last header (and its type / unit) to the front.
        hs_t1 = [hs_t[b][-1]] + hs_t[b][:-1]
        type_t1 = [col_types[b][-1]] + col_types[b][:-1]
        unit_t1 = [units[b][-1]] + units[b][:-1]
        l_hs.append(len(hs_t1))
        # [CLS] nlu [SEP] col1 [SEP] col2 [SEP] ...col-n [SEP]
        # 2. Generate BERT inputs & indices.
        tokens1, orders1, types1, segment1, match1, i_nlu1, i_hds_1, \
            start_idx, column_start, col_dict, schema_tok, \
            header_flatten_tokens, header_flatten_index, schema_link_matrix, schema_link_mask = \
            self.generate_inputs(
                nlu_t1, hs_t1, type_t1, unit_t1, his_sql[b],
                q_know[b], t_know[b], schema_link[b])
        # Number of question tokens actually placed in the sequence.
        l_n.append(i_nlu1[1] - i_nlu1[0])
        start_index.append(start_idx)
        column_index.append(column_start)
        col_dict_list.append(col_dict)
        tokens.append(tokens1)
        orders.append(orders1)
        types.append(types1)
        segments.append(segment1)
        matchs.append(match1)
        i_nlu.append(i_nlu1)
        i_hds.append(i_hds_1)
        schema_link_matrix_list.append(schema_link_matrix)
        schema_link_mask_list.append(schema_link_mask)
        header_flatten_token_list.append(header_flatten_tokens)
        header_flatten_index_list.append(header_flatten_index)
        header_list.append(schema_tok)
        # Track the longest header (in tokens) and the longest sequence.
        header_max = max([len(schema_tok1) for schema_tok1 in schema_tok])
        if header_max > header_tok_max_len:
            header_tok_max_len = header_max
        if len(tokens1) > cur_max_length:
            cur_max_length = len(tokens1)
        if len(tokens1) > 512:
            print('input too long!!! total_num:%d\t question:%s' %
                  (len(tokens1), ''.join(nlu_t1)))
    assert cur_max_length <= 512
    for i, tokens1 in enumerate(tokens):
        segment_ids1 = segments[i]
        order_ids1 = orders[i]
        type_ids1 = types[i]
        match_ids1 = matchs[i]
        input_ids1 = tokenizer.convert_tokens_to_ids(tokens1)
        input_mask1 = [1] * len(input_ids1)
        # Zero-pad every feature to the batch maximum. The feature lists are
        # the same objects stored in segments/orders/types/matchs, so those
        # are padded in place as well.
        while len(input_ids1) < cur_max_length:
            input_ids1.append(0)
            input_mask1.append(0)
            segment_ids1.append(0)
            order_ids1.append(0)
            type_ids1.append(0)
            match_ids1.append(0)
        if len(input_ids1) != cur_max_length:
            # NOTE(review): `nlu_t1` here is the leftover loop variable of
            # the previous loop (the last sample), not the current one.
            print('Error: ', nlu_t1, tokens1, len(input_ids1),
                  cur_max_length)
        assert len(input_ids1) == cur_max_length
        assert len(input_mask1) == cur_max_length
        assert len(order_ids1) == cur_max_length
        assert len(segment_ids1) == cur_max_length
        assert len(match_ids1) == cur_max_length
        assert len(type_ids1) == cur_max_length
        input_ids.append(input_ids1)
        order_ids.append(order_ids1)
        type_ids.append(type_ids1)
        segment_ids.append(segment_ids1)
        input_mask.append(input_mask1)
        match_ids.append(match_ids1)
    header_len = []
    header_ids = []
    # Largest number of headers of any sample in the batch.
    header_max_len = max(
        [len(header_list1) for header_list1 in header_list])
    for header1 in header_list:
        header_len1 = []
        header_ids1 = []
        for header_tok in header1:
            header_len1.append(len(header_tok))
            header_tok_ids1 = tokenizer.convert_tokens_to_ids(header_tok)
            # Pad each header to the longest header of the batch ...
            while len(header_tok_ids1) < header_tok_max_len:
                header_tok_ids1.append(0)
            header_ids1.append(header_tok_ids1)
        # ... and each sample to the largest header count.
        while len(header_ids1) < header_max_len:
            header_ids1.append([0] * header_tok_max_len)
        header_len.append(header_len1)
        header_ids.append(header_ids1)
    for i, header_flatten_token in enumerate(header_flatten_token_list):
        header_flatten_tokenid = tokenizer.convert_tokens_to_ids(
            header_flatten_token)
        header_flatten_tokenid_list.append(header_flatten_tokenid)
    # Convert to tensor
    all_input_ids = torch.tensor(
        input_ids, dtype=torch.long).to(self._device_name)
    all_order_ids = torch.tensor(
        order_ids, dtype=torch.long).to(self._device_name)
    all_type_ids = torch.tensor(
        type_ids, dtype=torch.long).to(self._device_name)
    all_input_mask = torch.tensor(
        input_mask, dtype=torch.long).to(self._device_name)
    all_segment_ids = torch.tensor(
        segment_ids, dtype=torch.long).to(self._device_name)
    all_match_ids = torch.tensor(
        match_ids, dtype=torch.long).to(self._device_name)
    all_header_ids = torch.tensor(
        header_ids, dtype=torch.long).to(self._device_name)
    # 0..batch_size-1, one id per sample.
    all_ids = torch.arange(
        all_input_ids.shape[0], dtype=torch.long).to(self._device_name)
    bS = len(header_flatten_tokenid_list)
    max_header_flatten_token_length = max(
        [len(x) for x in header_flatten_tokenid_list])
    all_header_flatten_tokens = numpy.zeros(
        (bS, max_header_flatten_token_length), dtype='int32')
    all_header_flatten_index = numpy.zeros(
        (bS, max_header_flatten_token_length), dtype='int32')
    for i, header_flatten_tokenid in enumerate(
            header_flatten_tokenid_list):
        for j, tokenid in enumerate(header_flatten_tokenid):
            all_header_flatten_tokens[i, j] = tokenid
        for j, hdindex in enumerate(header_flatten_index_list[i]):
            all_header_flatten_index[i, j] = hdindex
    # All-zero buffer with one more column than the largest header count;
    # it is only handed to the backbone.
    all_header_flatten_output = numpy.zeros((bS, header_max_len + 1),
                                            dtype='int32')
    all_header_flatten_tokens = torch.tensor(
        all_header_flatten_tokens, dtype=torch.long).to(self._device_name)
    all_header_flatten_index = torch.tensor(
        all_header_flatten_index, dtype=torch.long).to(self._device_name)
    all_header_flatten_output = torch.tensor(
        all_header_flatten_output,
        dtype=torch.float32).to(self._device_name)
    all_token_column_id = numpy.zeros((bS, cur_max_length), dtype='int32')
    all_token_column_mask = numpy.zeros((bS, cur_max_length),
                                        dtype='float32')
    for bi, col_dict in enumerate(col_dict_list):
        for ki, vi in col_dict.items():
            # Column ids are stored shifted by one; unset positions stay 0.
            all_token_column_id[bi, ki] = vi + 1
            all_token_column_mask[bi, ki] = 1.0
    all_token_column_id = torch.tensor(
        all_token_column_id, dtype=torch.long).to(self._device_name)
    all_token_column_mask = torch.tensor(
        all_token_column_mask, dtype=torch.float32).to(self._device_name)
    all_schema_link_matrix = numpy.zeros(
        (bS, cur_max_length, cur_max_length), dtype='int32')
    all_schema_link_mask = numpy.zeros(
        (bS, cur_max_length, cur_max_length), dtype='float32')
    for i, schema_link_matrix in enumerate(schema_link_matrix_list):
        # Copy each sample's square matrix into the top-left corner of the
        # padded batch array.
        temp_len = schema_link_matrix.shape[0]
        all_schema_link_matrix[i, 0:temp_len,
                               0:temp_len] = schema_link_matrix
        all_schema_link_mask[i, 0:temp_len,
                             0:temp_len] = schema_link_mask_list[i]
    all_schema_link_matrix = torch.tensor(
        all_schema_link_matrix, dtype=torch.long).to(self._device_name)
    # NOTE(review): the mask is built as float32 but converted to long here.
    all_schema_link_mask = torch.tensor(
        all_schema_link_mask, dtype=torch.long).to(self._device_name)
    # 5. generate l_hpu from i_hds
    l_hpu = self.gen_l_hpu(i_hds)
    # 4. Generate BERT output.
    all_encoder_layer, pooled_output = model_bert(
        all_input_ids,
        all_header_ids,
        token_order_ids=all_order_ids,
        token_type_ids=all_segment_ids,
        attention_mask=all_input_mask,
        match_type_ids=all_match_ids,
        l_hs=l_hs,
        header_len=header_len,
        type_ids=all_type_ids,
        col_dict_list=col_dict_list,
        ids=all_ids,
        header_flatten_tokens=all_header_flatten_tokens,
        header_flatten_index=all_header_flatten_index,
        header_flatten_output=all_header_flatten_output,
        token_column_id=all_token_column_id,
        token_column_mask=all_token_column_mask,
        column_start_index=column_index,
        headers_length=l_hs,
        all_schema_link_matrix=all_schema_link_matrix,
        all_schema_link_mask=all_schema_link_mask,
        output_all_encoded_layers=False)
    return all_encoder_layer, pooled_output, tokens, i_nlu, i_hds, \
        l_n, l_hpu, l_hs, start_index, column_index, all_ids
def predict(self, querys):
    """Run backbone and head on a batch and decode the result into SQL dicts.

    Args:
        querys: list of preprocessed sample dicts. Besides the fields read
            by `get_fields_info`, the keys 'question', 'question_tok' and
            'table_id' are copied into the result.

    Returns:
        One dict per sample with the keys 'question_tok', 'question',
        'table_id', 'sql' (a dict with 'cond_conn_op', 'sel', 'agg' and
        'conds'), 'action' and 'model_out' (the raw argmax predictions).
    """
    self.head_model.eval()
    self.backbone_model.eval()
    # train=False: the 'sql' / 'action' fields are not required, so sql_i
    # and tb come back as empty lists.
    nlu, nlu_t, sql_i, q_know, t_know, tb, hs_t, types, units, his_sql, schema_link = \
        self.get_fields_info(querys, None, train=False)
    with torch.no_grad():
        all_encoder_layer, _, tokens, i_nlu, i_hds, l_n, l_hpu, l_hs, start_index, column_index, ids = \
            self.get_bert_output(
                self.backbone_model, self.tokenizer,
                nlu_t, hs_t, types, units, his_sql, q_know, t_know, schema_link)
        s_action, s_sc, s_sa, s_cco, s_wc, s_wo, s_wvs, s_len = self.head_model(
            all_encoder_layer, l_n, l_hs, start_index, column_index,
            tokens, ids)
    # Pick the arg-max over the last axis of every head output.
    action_batch = torch.argmax(F.softmax(s_action, -1), -1).cpu().tolist()
    scco_batch = torch.argmax(F.softmax(s_cco, -1), -1).cpu().tolist()
    sc_batch = torch.argmax(F.softmax(s_sc, -1), -1).cpu().tolist()
    sa_batch = torch.argmax(F.softmax(s_sa, -1), -1).cpu().tolist()
    wc_batch = torch.argmax(F.softmax(s_wc, -1), -1).cpu().tolist()
    wo_batch = torch.argmax(F.softmax(s_wo, -1), -1).cpu().tolist()
    # s_wvs holds the start and the end scores of the where-value spans.
    s_wvs_s, s_wvs_e = s_wvs
    wvss_batch = torch.argmax(F.softmax(s_wvs_s, -1), -1).cpu().tolist()
    wvse_batch = torch.argmax(F.softmax(s_wvs_e, -1), -1).cpu().tolist()
    # s_len holds the select-count and the where-count scores.
    s_slen, s_wlen = s_len
    slen_batch = torch.argmax(F.softmax(s_slen, -1), -1).cpu().tolist()
    wlen_batch = torch.argmax(F.softmax(s_wlen, -1), -1).cpu().tolist()
    pr_wvi = []
    for i in range(len(querys)):
        wvi = []
        for j in range(wlen_batch[i]):
            # Shift the predicted positions by one (floored at 0);
            # presumably this removes the leading [CLS] offset -- TODO
            # confirm.
            wvi.append([
                max(0, wvss_batch[i][j] - 1),
                max(0, wvse_batch[i][j] - 1)
            ])
        pr_wvi.append(wvi)
    pr_wvi_str = self.convert_string(pr_wvi, nlu, nlu_t)
    pre_results = []
    for ib in range(len(querys)):
        res_one = {}
        sql = {}
        sql['cond_conn_op'] = scco_batch[ib]
        sl = slen_batch[ib]
        # Keep the first `sl` select predictions; column ids are shifted by
        # -1, so a predicted 0 becomes -1.
        sql['sel'] = list(
            numpy.array(sc_batch[ib][:sl]).astype(numpy.int32) - 1)
        sql['agg'] = list(
            numpy.array(sa_batch[ib][:sl]).astype(numpy.int32))
        sels = []
        aggs = []
        for ia, sel in enumerate(sql['sel']):
            if sel == -1:
                # A -1 column is kept only when it carries an aggregation;
                # it is then mapped to the last column index.
                if sql['agg'][ia] > 0:
                    sels.append(l_hs[ib] - 1)
                    aggs.append(sql['agg'][ia])
                continue
            sels.append(sel)
            if sql['agg'][ia] == -1:
                aggs.append(0)
            else:
                aggs.append(sql['agg'][ia])
        if len(sels) == 0:
            # Never return an empty select: fall back to the last column
            # index with aggregation 0.
            sels.append(l_hs[ib] - 1)
            aggs.append(0)
        assert len(sels) == len(aggs)
        sql['sel'] = sels
        sql['agg'] = aggs
        conds = []
        wl = wlen_batch[ib]
        wc_os = list(
            numpy.array(wc_batch[ib][:wl]).astype(numpy.int32) - 1)
        wo_os = list(numpy.array(wo_batch[ib][:wl]).astype(numpy.int32))
        res_one['question_tok'] = querys[ib]['question_tok']
        for i in range(wl):
            # Conditions on column -1 are dropped.
            if wc_os[i] == -1:
                continue
            conds.append([wc_os[i], wo_os[i], pr_wvi_str[ib][i]])
        if len(conds) == 0:
            # Placeholder condition used when nothing was predicted.
            conds.append([l_hs[ib] - 1, 2, 'Nulll'])
        sql['conds'] = conds
        res_one['question'] = querys[ib]['question']
        res_one['table_id'] = querys[ib]['table_id']
        res_one['sql'] = sql
        res_one['action'] = action_batch[ib]
        res_one['model_out'] = [
            sc_batch[ib], sa_batch[ib], wc_batch[ib], wo_batch[ib],
            wvss_batch[ib], wvse_batch[ib]
        ]
        pre_results.append(res_one)
    return pre_results
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
    """Run prediction and return the first result with its SQL history.

    Args:
        input (Dict[str, Tensor]): the preprocessed data. ``input['datas']``
            is handed to ``self.predict``; its first element must provide
            a ``'history_sql'`` entry.

    Returns:
        Dict[str, Tensor]: ``'result'`` is the first item returned by
            ``self.predict`` and ``'history_sql'`` is the value carried by
            the first input sample, forwarded unchanged.
    """
    samples = input['datas']
    first_prediction = self.predict(samples)[0]
    return dict(
        result=first_prediction, history_sql=samples[0]['history_sql'])
| @@ -35,6 +35,7 @@ class OutputKeys(object): | |||
| UUID = 'uuid' | |||
| WORD = 'word' | |||
| KWS_LIST = 'kws_list' | |||
| HISTORY = 'history' | |||
| TIMESTAMPS = 'timestamps' | |||
| SPLIT_VIDEO_NUM = 'split_video_num' | |||
| SPLIT_META_DICT = 'split_meta_dict' | |||
| @@ -471,6 +472,13 @@ TASK_OUTPUTS = { | |||
| # } | |||
| Tasks.conversational_text_to_sql: [OutputKeys.TEXT], | |||
| # table-question-answering result for single sample | |||
| # { | |||
| # "sql": "SELECT shop.Name FROM shop." | |||
| # "sql_history": {sel: 0, agg: 0, conds: [[0, 0, 'val']]} | |||
| # } | |||
| Tasks.table_question_answering: [OutputKeys.OUTPUT, OutputKeys.HISTORY], | |||
| # ============ audio tasks =================== | |||
| # asr result for single sample | |||
| # { "text": "每一天都要快乐喔"} | |||
| @@ -66,6 +66,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| Tasks.conversational_text_to_sql: | |||
| (Pipelines.conversational_text_to_sql, | |||
| 'damo/nlp_star_conversational-text-to-sql'), | |||
| Tasks.table_question_answering: | |||
| (Pipelines.table_question_answering_pipeline, | |||
| 'damo/nlp-convai-text2sql-pretrain-cn'), | |||
| Tasks.text_error_correction: | |||
| (Pipelines.text_error_correction, | |||
| 'damo/nlp_bart_text-error-correction_chinese'), | |||
| @@ -5,6 +5,7 @@ from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .conversational_text_to_sql_pipeline import ConversationalTextToSqlPipeline | |||
| from .table_question_answering_pipeline import TableQuestionAnsweringPipeline | |||
| from .dialog_intent_prediction_pipeline import DialogIntentPredictionPipeline | |||
| from .dialog_modeling_pipeline import DialogModelingPipeline | |||
| from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline | |||
| @@ -31,6 +32,8 @@ else: | |||
| _import_structure = { | |||
| 'conversational_text_to_sql_pipeline': | |||
| ['ConversationalTextToSqlPipeline'], | |||
| 'table_question_answering_pipeline': | |||
| ['TableQuestionAnsweringPipeline'], | |||
| 'dialog_intent_prediction_pipeline': | |||
| ['DialogIntentPredictionPipeline'], | |||
| 'dialog_modeling_pipeline': ['DialogModelingPipeline'], | |||
| @@ -0,0 +1,284 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| from typing import Any, Dict, Union | |||
| import torch | |||
| from transformers import BertTokenizer | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import TableQuestionAnswering | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import TableQuestionAnsweringPreprocessor | |||
| from modelscope.preprocessors.star3.fields.database import Database | |||
| from modelscope.preprocessors.star3.fields.struct import Constant, SQLQuery | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| __all__ = ['TableQuestionAnsweringPipeline'] | |||
@PIPELINES.register_module(
    Tasks.table_question_answering,
    module_name=Pipelines.table_question_answering_pipeline)
class TableQuestionAnsweringPipeline(Pipeline):
    """Table question answering pipeline.

    The pipeline merges the SQL structure predicted for the current turn
    with the SQL of the previous turn (according to the predicted dialogue
    action) and renders the merged structure as SQL text.
    """

    # Actions that are resolved by editing the previous turn's SQL.
    _MERGE_ACTIONS = ('del_focus', 'change_agg_only', 'change_focus_total',
                      'del_cond', 'change_cond', 'add_cond')

    def __init__(self,
                 model: Union[TableQuestionAnswering, str],
                 preprocessor: TableQuestionAnsweringPreprocessor = None,
                 db: Database = None,
                 **kwargs):
        """use `model` and `preprocessor` to create a table question answering prediction pipeline

        Args:
            model (TableQuestionAnswering): a model instance, or a string
                that is passed to ``Model.from_pretrained``
            preprocessor (TableQuestionAnsweringPreprocessor): a preprocessor
                instance; built from ``model.model_dir`` when omitted
            db (Database): a database to store tables in the database; when
                omitted it is loaded from ``table.json`` and ``synonym.txt``
                inside ``model.model_dir``
        """
        model = model if isinstance(
            model, TableQuestionAnswering) else Model.from_pretrained(model)
        if preprocessor is None:
            preprocessor = TableQuestionAnsweringPreprocessor(model.model_dir)

        # initialize tokenizer
        self.tokenizer = BertTokenizer(
            os.path.join(model.model_dir, ModelFile.VOCAB_FILE))

        # initialize database
        if db is None:
            self.db = Database(
                tokenizer=self.tokenizer,
                table_file_path=os.path.join(model.model_dir, 'table.json'),
                syn_dict_file_path=os.path.join(model.model_dir,
                                                'synonym.txt'))
        else:
            self.db = db

        # lookup tables used to decode the model's integer predictions
        constant = Constant()
        self.agg_ops = constant.agg_ops
        self.cond_ops = constant.cond_ops
        self.cond_conn_ops = constant.cond_conn_ops
        self.action_ops = constant.action_ops
        self.max_select_num = constant.max_select_num
        self.max_where_num = constant.max_where_num
        self.col_type_dict = constant.col_type_dict
        self.schema_link_dict = constant.schema_link_dict

        super().__init__(model=model, preprocessor=preprocessor, **kwargs)

    @staticmethod
    def _keep_valid_conds(conds, num_headers, placeholder):
        """Keep conditions on real columns; use `placeholder` if none remain."""
        valid = [cond for cond in conds if cond[0] < num_headers]
        if not valid:
            valid.append(placeholder)
        return valid

    def post_process_multi_turn(self, history_sql, result, table):
        """Merge the current turn's SQL with the previous turn's SQL.

        Args:
            history_sql (dict or None): SQL structure of the previous turn
                (keys ``sel``, ``agg``, ``conds`` are read here). It is
                never modified; merged results are built on a deep copy.
            result (dict): model output of the current turn; ``'action'``
                indexes ``self.action_ops`` and ``'sql'`` is the current
                SQL structure.
            table (dict): table description; only ``'header_name'`` is used.

        Returns:
            dict: the SQL structure to use for this turn. This is
                ``history_sql`` itself for ``out_of_scripts``, the current
                SQL for the first turn and for actions that do not merge,
                and a new merged dict otherwise.
        """
        # Local import: the module-level imports do not provide `copy`.
        import copy

        action = self.action_ops[result['action']]
        num_headers = len(table['header_name'])
        current_sql = result['sql']

        if history_sql is None:
            return current_sql
        if action == 'out_of_scripts':
            return history_sql
        if action not in self._MERGE_ACTIONS:
            # switch_table / restart / firstTurn and any other action
            return current_sql

        # Work on a copy so the caller's history is left untouched.
        merged = copy.deepcopy(history_sql)

        if action == 'del_focus':
            # Remove the selected columns named in the current turn.
            sels = []
            aggs = []
            for idx, sel in enumerate(merged['sel']):
                if sel not in current_sql['sel']:
                    sels.append(sel)
                    aggs.append(merged['agg'][idx])
            if len(sels) < 1:
                # index == number of headers denotes the empty column
                sels.append(num_headers)
                aggs.append(0)
            merged['sel'] = sels
            merged['agg'] = aggs
            merged['conds'] = self._keep_valid_conds(
                merged['conds'], num_headers, [num_headers, 2, 'Null'])

        elif action == 'change_agg_only':
            # Keep the selected columns, replacing the aggregation of those
            # that also appear in the current turn.
            sels = []
            aggs = []
            for idx, sel in enumerate(merged['sel']):
                if sel in current_sql['sel']:
                    sels.append(sel)
                    changed_agg = -1
                    for cur_idx, cur_agg in enumerate(current_sql['agg']):
                        if current_sql['sel'][cur_idx] == sel:
                            changed_agg = cur_agg
                    aggs.append(changed_agg)
                else:
                    sels.append(sel)
                    aggs.append(merged['agg'][idx])
            merged['sel'] = sels
            merged['agg'] = aggs

        elif action == 'change_focus_total':
            # Take the current selection; add current conditions on columns
            # that have no condition yet.
            merged['sel'] = current_sql['sel']
            merged['agg'] = current_sql['agg']
            for cur_cond in current_sql['conds']:
                if cur_cond[0] < num_headers:
                    already_constrained = False
                    for known_cond in merged['conds']:
                        if cur_cond[0] == known_cond[0]:
                            already_constrained = True
                    if not already_constrained:
                        merged['conds'].append(cur_cond)

        elif action == 'del_cond':
            # Drop conditions whose column is in the current selection.
            remaining = [
                cond for cond in merged['conds']
                if cond[0] not in current_sql['sel']
            ]
            merged['conds'] = self._keep_valid_conds(
                remaining, num_headers, [num_headers, 2, 'Null'])

        elif action == 'change_cond':
            # Replace conditions on columns that the current turn constrains.
            replaced = []
            for cond in merged['conds']:
                overridden = False
                for cur_cond in current_sql['conds']:
                    if cond[0] == cur_cond[0]:
                        overridden = True
                        replaced.append(cur_cond)
                if not overridden:
                    replaced.append(cond)
            # NOTE(review): this placeholder has four elements while every
            # other branch uses three; kept as-is, confirm whether intended.
            merged['conds'] = self._keep_valid_conds(
                replaced, num_headers, [num_headers, 2, 'Null', 'Null'])

        else:  # action == 'add_cond'
            combined = merged['conds'] + [
                cond for cond in current_sql['conds']
                if cond[0] < num_headers
            ]
            merged['conds'] = self._keep_valid_conds(
                combined, num_headers, [num_headers, 2, 'Null'])

        return merged

    def sql_dict_to_str(self, result, table):
        """
        convert sql struct to string

        Args:
            result (dict): holds the SQL structure under ``'sql'`` (keys
                ``sel``, ``agg``, ``conds``, ``cond_conn_op``).
            table (dict): table description providing ``header_name``,
                ``header_id``, ``table_id`` and ``table_name``.

        Returns:
            SQLQuery: ``string`` is rendered with header/table names,
                ``query`` with header/table ids, ``sql_result`` is the
                input structure.
        """
        # The extra trailing entry is the empty column, addressed by
        # index == len(table['header_name']).
        header_names = table['header_name'] + ['空列']
        header_ids = table['header_id'] + ['null']
        sql = result['sql']

        str_sel_list, sql_sel_list = [], []
        for idx, sel in enumerate(sql['sel']):
            header_name = header_names[sel]
            header_id = '`%s`.`%s`' % (table['table_id'], header_ids[sel])
            if sql['agg'][idx] == 0:
                str_sel_list.append(header_name)
                sql_sel_list.append(header_id)
            else:
                str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '( '
                                    + header_name + ' )')
                sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '( '
                                    + header_id + ' )')

        str_cond_list, sql_cond_list = [], []
        for cond in sql['conds']:
            header_name = header_names[cond[0]]
            header_id = '`%s`.`%s`' % (table['table_id'], header_ids[cond[0]])
            op = self.cond_ops[cond[1]]
            # NOTE(review): the value is embedded without escaping; quote
            # characters inside it would break the generated SQL. Escape or
            # parameterize before executing `query` against a real database.
            value = cond[2]
            str_cond_list.append('( ' + header_name + ' ' + op + ' "' + value
                                 + '" )')
            sql_cond_list.append('( ' + header_id + ' ' + op + ' "' + value
                                 + '" )')

        cond = ' ' + self.cond_conn_ops[sql['cond_conn_op']] + ' '

        final_str = 'SELECT %s FROM %s WHERE %s' % (', '.join(str_sel_list),
                                                    table['table_name'],
                                                    cond.join(str_cond_list))
        final_sql = 'SELECT %s FROM `%s` WHERE %s' % (', '.join(sql_sel_list),
                                                      table['table_id'],
                                                      cond.join(sql_cond_list))
        sql = SQLQuery(
            string=final_str, query=final_sql, sql_result=result['sql'])

        return sql

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """process the prediction results

        Args:
            inputs (Dict[str, Any]): the model output, with the current
                prediction under ``'result'`` and the previous turn's SQL
                under ``'history_sql'``

        Returns:
            Dict[str, str]: the prediction results; ``OutputKeys.OUTPUT`` is
                the rendered ``SQLQuery`` and ``OutputKeys.HISTORY`` the
                merged SQL structure
        """
        result = inputs['result']
        history_sql = inputs['history_sql']
        result['sql'] = self.post_process_multi_turn(
            history_sql=history_sql,
            result=result,
            table=self.db.tables[result['table_id']])
        sql = self.sql_dict_to_str(
            result=result, table=self.db.tables[result['table_id']])
        output = {OutputKeys.OUTPUT: sql, OutputKeys.HISTORY: result['sql']}
        return output

    def _collate_fn(self, data):
        """Return `data` unchanged."""
        return data
| @@ -30,6 +30,7 @@ if TYPE_CHECKING: | |||
| DialogStateTrackingPreprocessor) | |||
| from .video import ReadVideoData, MovieSceneSegmentationPreprocessor | |||
| from .star import ConversationalTextToSqlPreprocessor | |||
| from .star3 import TableQuestionAnsweringPreprocessor | |||
| else: | |||
| _import_structure = { | |||
| @@ -62,6 +63,7 @@ else: | |||
| 'DialogStateTrackingPreprocessor', 'InputFeatures' | |||
| ], | |||
| 'star': ['ConversationalTextToSqlPreprocessor'], | |||
| 'star3': ['TableQuestionAnsweringPreprocessor'], | |||
| } | |||
| import sys | |||
| @@ -0,0 +1,24 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
# Static type checkers take the eager imports; at runtime this package's entry
# in `sys.modules` is replaced by a LazyImportModule built from
# `_import_structure`.
if TYPE_CHECKING:
    from .table_question_answering_preprocessor import TableQuestionAnsweringPreprocessor
    # NOTE(review): these two names are not listed in `_import_structure`
    # below; confirm that `.fields` of this package really provides them.
    from .fields import MultiWOZBPETextField, IntentBPETextField

else:
    # submodule name -> public names exported from it
    _import_structure = {
        'table_question_answering_preprocessor':
        ['TableQuestionAnsweringPreprocessor'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
| @@ -0,0 +1,77 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import json | |||
| import tqdm | |||
| from modelscope.preprocessors.star3.fields.struct import Trie | |||
class Database:
    """In-memory store of tables and column synonyms.

    Attributes are filled at construction time: ``tables`` maps a table id
    to its (augmented) table dict and ``syn_dict`` maps a lower-cased key to
    a list of lower-cased synonyms.
    """

    def __init__(self, tokenizer, table_file_path, syn_dict_file_path):
        """Load tables and synonyms from disk.

        Args:
            tokenizer: object with a ``tokenize(text)`` method, used to
                tokenize header names and units.
            table_file_path: path of a file holding one JSON table per line.
            syn_dict_file_path: path of a tab-separated synonym file.
        """
        self.tokenizer = tokenizer
        self.tables = self.init_tables(table_file_path=table_file_path)
        self.syn_dict = self.init_syn_dict(
            syn_dict_file_path=syn_dict_file_path)

    def init_tables(self, table_file_path):
        """Read the table file and build per-table lookup structures.

        Every table dict read from the file is extended in place with
        ``tablelen``, ``header_tok``, a trailing ``'null'`` header type,
        tokenized ``header_units`` and ``value_trie`` (one ``Trie`` of cell
        values per header).

        Returns:
            dict: table id -> table dict.
        """
        # Read as UTF-8 explicitly (same as the synonym file) so loading
        # does not depend on the platform's default encoding; blank lines
        # are skipped instead of failing in json.loads.
        with open(table_file_path, 'r', encoding='utf-8') as fo:
            lines = [line for line in fo if line.strip()]

        tables = {}
        for line in tqdm.tqdm(lines, desc='Load Tables'):
            table = json.loads(line.strip())

            table_header_length = 0
            headers_tokens = []
            for header in table['header_name']:
                header_tokens = self.tokenizer.tokenize(header)
                table_header_length += len(header_tokens)
                headers_tokens.append(header_tokens)
            # An extra "empty column" entry is appended after the real
            # headers, together with a matching 'null' type and empty unit.
            empty_column = self.tokenizer.tokenize('空列')
            table_header_length += len(empty_column)
            headers_tokens.append(empty_column)
            table['tablelen'] = table_header_length
            table['header_tok'] = headers_tokens
            table['header_types'].append('null')
            table['header_units'] = [
                self.tokenizer.tokenize(unit) for unit in table['header_units']
            ] + [[]]

            # One trie per real header, filled with the lower-cased cell
            # values of the non-numeric columns.
            trie_set = [Trie() for _ in table['header_name']]
            for row in table['rows']:
                for ii, cell in enumerate(row):
                    col_type = table['header_types'][ii].lower()
                    if 'real' in col_type or \
                            'number' in col_type or \
                            'duration' in col_type:
                        continue
                    word = str(cell).strip().lower()
                    trie_set[ii].insert(word, word)
            table['value_trie'] = trie_set

            tables[table['table_id']] = table

        return tables

    def init_syn_dict(self, syn_dict_file_path):
        """Read the synonym file.

        Each useful line has two tab-separated fields, each a ``|``-separated
        list; every key of the first field is mapped to all values of the
        second. Lines without exactly two fields are ignored.

        Returns:
            dict: lower-cased key -> list of lower-cased synonyms.
        """
        with open(syn_dict_file_path, encoding='utf-8') as fo:
            lines = list(fo)

        syn_dict = {}
        for line in tqdm.tqdm(lines, desc='Load Synonym Dict'):
            tokens = line.strip().split('\t')
            if len(tokens) != 2:
                continue
            keys = tokens[0].strip().split('|')
            values = tokens[1].strip().split('|')
            for key in keys:
                key = key.lower().strip()
                syn_dict.setdefault(key, []).extend(
                    value.lower().strip() for value in values)

        return syn_dict
| @@ -0,0 +1,423 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import re | |||
| from modelscope.preprocessors.star3.fields.struct import TypeInfo | |||
class SchemaLinker:
    """Link spans of a question to the columns and cell values of tables."""

    def __init__(self):
        pass

    def find_in_list(self, comlist, words):
        """Return True if `words` is contained in any element of `comlist`."""
        return any(words in com for com in comlist)

    def get_continue_score(self, pstr, tstr):
        """Score how much of `tstr` occurs in `pstr` as contiguous substrings.

        Substrings of `tstr` of length >= 2 are tried longest first; a
        substring already covered by a longer common substring is not
        counted again. The result is capped at 1.0.
        """
        comlist = []
        minlen = min(len(pstr), len(tstr))
        for slen in range(minlen, 1, -1):
            for ts in range(0, len(tstr), 1):
                if ts + slen > len(tstr):
                    continue
                words = tstr[ts:ts + slen]
                if words in pstr and not self.find_in_list(comlist, words):
                    comlist.append(words)

        comlen = 0
        for com in comlist:
            comlen += len(com) * len(com)
        # +0.001 avoids a division by zero for an empty target
        weight = comlen / (len(tstr) * len(tstr) + 0.001)
        if weight > 1.0:
            weight = 1.0

        return weight

    def get_match_score(self, ptokens, ttokens):
        """Combine token-set overlap (0.4) with contiguous overlap (0.6)."""
        pset = set(ptokens)
        tset = set(ttokens)
        comset = pset & tset
        allset = pset | tset
        weight2 = len(comset) / (len(allset) + 0.001)
        weight3 = self.get_continue_score(''.join(ptokens), ''.join(ttokens))
        return 0.4 * weight2 + 0.6 * weight3

    def is_number(self, s):
        """Return True if `s` parses as a float or is a numeric character."""
        try:
            float(s)
            return True
        except ValueError:
            pass

        try:
            import unicodedata
            unicodedata.numeric(s)
            return True
        except (TypeError, ValueError):
            pass

        return False

    def get_match_phrase(self, query, target):
        """Find the phrase of `query` that best matches `target`.

        Returns:
            tuple: ``(phrase, score)``. An exact occurrence of `target` in
                `query` yields ``(target, 1.0)``.
        """
        if target in query:
            return target, 1.0

        # character-level tokens
        qtokens = []
        for i in range(0, len(query), 1):
            qtokens.append(query[i:i + 1])
        ttokens = []
        for i in range(0, len(target), 1):
            ttokens.append(target[i:i + 1])
        ttok_set = set(ttokens)

        phrase = ''
        score = 0.0
        for qidx, qword in enumerate(qtokens):
            if qword not in ttok_set:
                continue

            # candidate phrases start at qidx and span at most twice the
            # target length
            eidx = min(len(qtokens), qidx + 2 * len(ttokens))
            while eidx > qidx:
                ptokens = qtokens[qidx:eidx]
                weight = self.get_match_score(ptokens, ttokens)
                if weight + 0.001 > score:
                    score = weight
                    phrase = ''.join(ptokens)
                eidx -= 1

        # a numeric target only counts when matched exactly
        if self.is_number(target) and phrase != target:
            score = 0.0
        if len(phrase) > 1 and phrase in target:
            score *= (1.0 + 0.05 * len(phrase))

        return phrase, score

    def allfindpairidx(self, que_tok, value_tok, weight):
        """Find every case-insensitive occurrence of `value_tok` in `que_tok`.

        Returns:
            list: ``[start, end, weight]`` triples with inclusive indices.
        """
        idxs = []
        for i in range(0, len(que_tok) - len(value_tok) + 1, 1):
            s = i
            e = i
            matched = True
            for j in range(0, len(value_tok), 1):
                if value_tok[j].lower() == que_tok[i + j].lower():
                    e = i + j
                else:
                    matched = False
                    break
            if matched:
                idxs.append([s, e, weight])

        return idxs

    def findnear(self, ps1, pe1, ps2, pe2):
        """Return True if the two spans are at most two positions apart."""
        return abs(ps1 - pe2) <= 2 or abs(pe1 - ps2) <= 2

    def get_column_type(self, col_idx, table):
        """Map a header type to one of 'real', 'date', 'bool' or 'text'."""
        colType = table['header_types'][col_idx]
        if 'number' in colType or 'duration' in colType or 'real' in colType:
            colType = 'real'
        elif 'date' in colType:
            colType = 'date'
        elif 'bool' in colType:
            colType = 'bool'
        else:
            colType = 'text'

        return colType

    def add_type_all(self, typeinfos, index, idxs, label, linktype, value,
                     orgvalue):
        """Insert one TypeInfo per span of `idxs`, keeping `pstart` order.

        `typeinfos` is modified in place and also returned.
        """
        for idx in idxs:
            info = TypeInfo(label, index, linktype, value, orgvalue, idx[0],
                            idx[1], idx[2])
            flag = True
            for i, typeinfo in enumerate(typeinfos):
                if info.pstart < typeinfo.pstart:
                    typeinfos.insert(i, info)
                    flag = False
                    break

            if flag:
                typeinfos.append(info)

        return typeinfos

    def save_info(self, tinfo, sinfo):
        """Decide whether `tinfo` is kept when `sinfo` is also present.

        Returns:
            bool: False when `tinfo` should be discarded in favour of
                `sinfo`, True otherwise.
        """
        flag = True
        if tinfo.pstart > sinfo.pend or tinfo.pend < sinfo.pstart:
            # no overlap
            pass
        elif tinfo.pstart >= sinfo.pstart and \
                tinfo.pend <= sinfo.pend and tinfo.index == -1:
            # tinfo is an unbound item covered by sinfo
            flag = False
        elif tinfo.pstart == sinfo.pstart and sinfo.pend == tinfo.pend and \
                abs(tinfo.weight - sinfo.weight) < 0.01:
            # same span, practically same weight: keep both
            pass
        else:
            if sinfo.label == 'col' or sinfo.label == 'val':
                if tinfo.label == 'col' or tinfo.label == 'val':
                    if (sinfo.pend
                            - sinfo.pstart) > (tinfo.pend - tinfo.pstart) or (
                                sinfo.weight > tinfo.weight
                                and sinfo.index != -1):
                        flag = False
                else:
                    flag = False
            else:
                if (tinfo.label == 'op' or tinfo.label == 'agg'):
                    if (sinfo.pend - sinfo.pstart) > (
                            tinfo.pend
                            - tinfo.pstart) or sinfo.weight > tinfo.weight:
                        flag = False

        return flag

    def normal_type_infos(self, infos):
        """Resolve overlapping infos pairwise with `save_info`.

        Returns:
            list: the surviving infos, ordered by `pstart`.
        """
        typeinfos = []
        for info in infos:
            # drop already accepted infos that lose against the new one
            typeinfos = [x for x in typeinfos if self.save_info(x, info)]
            flag = True
            for i, typeinfo in enumerate(typeinfos):
                if not self.save_info(info, typeinfo):
                    flag = False
                    break
                if info.pstart < typeinfo.pstart:
                    typeinfos.insert(i, info)
                    flag = False
                    break
            if flag:
                typeinfos.append(info)
        return typeinfos

    def findnear_typeinfo(self, info1, info2):
        """Apply `findnear` to the spans of two infos."""
        return self.findnear(info1.pstart, info1.pend, info2.pstart,
                             info2.pend)

    def find_real_column(self, infos, table):
        """Bind unbound values (index -1) to a nearby 'real' column.

        A column qualifies when it is near the value itself or near an
        operator that is near the value. `infos` is updated in place.
        """
        for i, vinfo in enumerate(infos):
            if vinfo.index != -1 or vinfo.label != 'val':
                continue
            eoidx = -1
            for j, oinfo in enumerate(infos):
                if oinfo.label != 'op':
                    continue
                if self.findnear_typeinfo(vinfo, oinfo):
                    eoidx = j
                    break
            for j, cinfo in enumerate(infos):
                if cinfo.label != 'col' or table['header_types'][
                        cinfo.index] != 'real':
                    continue
                if self.findnear_typeinfo(cinfo, vinfo) or (
                        eoidx != -1
                        and self.findnear_typeinfo(cinfo, infos[eoidx])):
                    infos[i].index = cinfo.index
                    break

        return infos

    def filter_column_infos(self, infos):
        """Remove column infos that share exactly the same span."""
        delid = []
        for i, info in enumerate(infos):
            if info.label != 'col':
                continue
            for j in range(i + 1, len(infos), 1):
                if infos[j].label == 'col' and \
                        info.pstart == infos[j].pstart and \
                        info.pend == infos[j].pend:
                    delid.append(i)
                    delid.append(j)
                    break

        dropped = set(delid)
        return [info for idx, info in enumerate(infos) if idx not in dropped]

    def filter_type_infos(self, infos, table):
        """Remove ambiguous columns and values covered by linked values.

        A value whose column is not mentioned in the question is dropped
        when its span lies inside a value of a mentioned column.
        """
        infos = self.filter_column_infos(infos)
        infos = self.find_real_column(infos, table)

        # mentioned column index -> value infos bound to that column
        colvalMp = {}
        for info in infos:
            if info.label == 'col':
                colvalMp[info.index] = []
        for info in infos:
            if info.label == 'val' and info.index in colvalMp:
                colvalMp[info.index].append(info)

        delid = []
        for idx, info in enumerate(infos):
            if info.label != 'val' or info.index in colvalMp:
                continue
            for index in colvalMp.keys():
                valinfos = colvalMp[index]
                for valinfo in valinfos:
                    if valinfo.pstart <= info.pstart and \
                            valinfo.pend >= info.pend:
                        delid.append(idx)
                        break

        dropped = set(delid)
        return [info for idx, info in enumerate(infos) if idx not in dropped]

    def get_table_match_score(self, nlu_t, schema_link):
        """Score a table by the weighted length of its linked spans.

        Links bound to a column count fully, other links count half; the sum
        is normalised by the question length.
        """
        match_len = 0
        for info in schema_link:
            if info['question_len'] > 0 and info['column_index'] != -1:
                scale = 1.0
            else:
                scale = 0.5
            match_len += scale * info['question_len'] * info['weight']

        return match_len / (len(nlu_t) + 0.1)

    def get_entity_linking(self, tokenizer, nlu, nlu_t, tables, col_syn_dict):
        """
        get linking between question and schema column

        Args:
            tokenizer: object with a ``tokenize(text)`` method.
            nlu (str): the question text.
            nlu_t (list): the question tokens.
            tables (dict): table id -> table dict. ``header_name``,
                ``header_types`` and ``table_id`` are read; ``value_trie``
                is optional.
            col_syn_dict (dict): column name -> list of synonyms.

        Returns:
            list: at most four dicts (best ``match_score`` first) with keys
                ``table_id``, ``question_knowledge``, ``header_knowledge``,
                ``schema_link`` and ``match_score``.
        """
        numbers = re.findall(r'[-]?\d*\.\d+|[-]?\d+|\d+', nlu)
        nlu_lower = nlu.lower()

        # search schema link in every table
        search_result_list = []
        for tablename in tables:
            table = tables[tablename]
            # Tables without a value trie simply get no value links.
            trie_set = table.get('value_trie') or []

            typeinfos = []
            for ii, column in enumerate(table['header_name']):
                column = column.lower()
                # Strip parenthesised annotations (ASCII and full-width
                # brackets) from the header. The parentheses must be written
                # explicitly: an unescaped '(.*?)' only matches the empty
                # string and removes nothing.
                column_new = re.sub(r'\(.*?\)', '', column)
                column_new = re.sub(r'（.*?）', '', column_new)
                cphrase, cscore = self.get_match_phrase(nlu_lower, column_new)
                if cscore > 0.3 and cphrase.strip() != '':
                    phrase_tok = tokenizer.tokenize(cphrase)
                    cidxs = self.allfindpairidx(nlu_t, phrase_tok, cscore)
                    typeinfos = self.add_type_all(typeinfos, ii, cidxs, 'col',
                                                  'column', cphrase, column)
                # weak direct match: also try the column's synonyms
                if cscore < 0.8 and column_new in col_syn_dict:
                    columns = list(set(col_syn_dict[column_new]))
                    for syn_col in columns:
                        if syn_col not in nlu_lower or syn_col == '':
                            continue
                        phrase_tok = tokenizer.tokenize(syn_col)
                        cidxs = self.allfindpairidx(nlu_t, phrase_tok, 1.0)
                        typeinfos = self.add_type_all(typeinfos, ii, cidxs,
                                                      'col', 'column', syn_col,
                                                      column)

            # cell values found in the question
            for ii, trie in enumerate(trie_set):
                ans = trie.match(nlu_lower)
                for cell in ans.keys():
                    vphrase = cell
                    vscore = 1.0
                    phrase_tok = tokenizer.tokenize(vphrase)
                    if len(phrase_tok) == 0 or len(vphrase) < 2:
                        continue
                    vidxs = self.allfindpairidx(nlu_t, phrase_tok, vscore)
                    linktype = self.get_column_type(ii, table)
                    typeinfos = self.add_type_all(typeinfos, ii, vidxs, 'val',
                                                  linktype, vphrase, ans[cell])

            # numbers become values that are not bound to a column yet
            for number in set(numbers):
                number_tok = tokenizer.tokenize(number.lower())
                if len(number_tok) == 0:
                    continue
                nidxs = self.allfindpairidx(nlu_t, number_tok, 1.0)
                typeinfos = self.add_type_all(typeinfos, -1, nidxs, 'val',
                                              'real', number, number)

            newtypeinfos = self.normal_type_infos(typeinfos)
            newtypeinfos = self.filter_type_infos(newtypeinfos, table)

            # Per-token codes: 1/2/3 = first/inner/last token of a linked
            # value, 4 = column mention, 5 = unbound value, larger codes are
            # derived from op/agg link types. Per-header flags: +1 column
            # mentioned, +2 value of the column found.
            final_question = [0] * len(nlu_t)
            final_header = [0] * len(table['header_name'])
            for typeinfo in newtypeinfos:
                pstart = typeinfo.pstart
                pend = typeinfo.pend + 1
                if typeinfo.label == 'op' or typeinfo.label == 'agg':
                    score = int(typeinfo.linktype[-1])
                    if typeinfo.label == 'op':
                        score += 6
                    else:
                        score += 11
                    for i in range(pstart, pend, 1):
                        final_question[i] = score

                elif typeinfo.label == 'col':
                    for i in range(pstart, pend, 1):
                        final_question[i] = 4
                    if final_header[typeinfo.index] % 2 == 0:
                        final_header[typeinfo.index] += 1

                elif typeinfo.label == 'val':
                    if typeinfo.index == -1:
                        for i in range(pstart, pend, 1):
                            final_question[i] = 5
                    else:
                        for i in range(pstart, pend, 1):
                            final_question[i] = 2
                        final_question[pstart] = 1
                        final_question[pend - 1] = 3
                        if final_header[typeinfo.index] < 2:
                            final_header[typeinfo.index] += 2

            # collect schema_link
            schema_link = []
            for sl in newtypeinfos:
                if sl.label in ['val', 'col']:
                    schema_link.append({
                        'question_len':
                        max(0, sl.pend - sl.pstart + 1),
                        'question_index': [sl.pstart, sl.pend],
                        'question_span':
                        ''.join(nlu_t[sl.pstart:sl.pend + 1]),
                        'column_index':
                        sl.index,
                        'column_span':
                        table['header_name'][sl.index]
                        if sl.index != -1 else '空列',
                        'label':
                        sl.label,
                        'weight':
                        round(sl.weight, 4)
                    })

            # get the match score of each table
            match_score = self.get_table_match_score(nlu_t, schema_link)

            search_result = {
                'table_id': table['table_id'],
                'question_knowledge': final_question,
                'header_knowledge': final_header,
                'schema_link': schema_link,
                'match_score': match_score
            }
            search_result_list.append(search_result)

        # keep the four best matching tables
        search_result_list = sorted(
            search_result_list, key=lambda x: x['match_score'],
            reverse=True)[0:4]

        return search_result_list
| @@ -0,0 +1,181 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
# Module-level operator vocabularies. The same values are also exposed by
# `Constant` (as `cond_ops`, `agg_ops` and `cond_conn_ops`).

# condition operator tokens
cond_ops = ['>', '<', '==', '!=', 'ASC', 'DESC']
# aggregation tokens; the first entry is the empty string
agg_ops = [
    '', 'AVG', 'MAX', 'MIN', 'COUNT', 'SUM', 'COMPARE', 'GROUP BY', 'SAME'
]
# connectors placed between conditions; the first entry is the empty string
conn_ops = ['', 'AND', 'OR']
class Context:
    """Holder for the SQL carried over between turns."""

    def __init__(self):
        # no history until `set_history_sql` is called
        self.history_sql = None

    def set_history_sql(self, sql):
        """Store `sql` as the current history."""
        self.history_sql = sql
class SQLQuery:
    """Plain container bundling three representations of one query."""

    def __init__(self, string, query, sql_result):
        # NOTE(review): presumably `string` is the readable rendering,
        # `query` the executable SQL text and `sql_result` the structured
        # form; confirm against the code that builds SQLQuery objects.
        self.string = string
        self.query = query
        self.sql_result = sql_result
class TrieNode(object):
    """A node of :class:`Trie`."""

    def __init__(self):
        """Create an empty node that does not terminate any word."""
        # child nodes keyed by the next character
        self.data = {}
        # True when an inserted word ends at this node
        self.is_word = False
        # payload stored with the word that ends here (None otherwise)
        self.term = None


class Trie(object):
    """Character trie that maps inserted words to an associated term."""

    def __init__(self):
        self.root = TrieNode()

    def insert(self, word, term):
        """
        Inserts a word into the trie and attaches `term` to it.
        :type word: str
        :rtype: None
        """
        node = self.root
        for letter in word:
            child = node.data.get(letter)
            if not child:
                node.data[letter] = TrieNode()
            node = node.data[letter]
        node.is_word = True
        node.term = term

    def search(self, word):
        """
        Follows `word` from the root.
        :type word: str
        :rtype: tuple
        :return: ``(None, False)`` if the path does not exist; otherwise
            ``(term, True)`` where ``term`` is None when `word` is only a
            prefix of inserted words.
        """
        node = self.root
        for letter in word:
            node = node.data.get(letter)
            if not node:
                return None, False
        return node.term, True

    def match(self, query):
        """
        Finds the inserted words occurring as substrings of `query`.

        From every start position the substring is extended while it is
        still a path in the trie. Substrings ending on the last character
        of `query` are included, and the scan continues with the next
        start position after an extension reaches the end of `query`.
        :type query: str
        :rtype: dict
        :return: matched substring -> its term (words stored with a None
            term are not reported).
        """
        length = len(query)
        ans = {}
        start = 0
        end = 1
        while start < length:
            if end > length:
                # ran off the end of the query: try the next start position
                start += 1
                end = start + 1
                continue
            sub = query[start:end]
            term, flag = self.search(sub)
            if flag:
                if term is not None:
                    ans[sub] = term
                end += 1
            else:
                start += 1
                end = start + 1
        return ans

    def starts_with(self, prefix):
        """
        Returns if there is any word in the trie
        that starts with the given prefix.
        :type prefix: str
        :rtype: bool
        """
        node = self.root
        for letter in prefix:
            node = node.data.get(letter)
            if not node:
                return False
        return True

    def get_start(self, prefix):
        """
        Returns words started with prefix. If `prefix` itself is an
        inserted word, only `prefix` is returned.
        :param prefix:
        :return: words (list)
        """

        def _get_key(pre, pre_node):
            # depth-first collection of every word below `pre_node`
            words_list = []
            if pre_node.is_word:
                words_list.append(pre)
            for x in pre_node.data.keys():
                words_list.extend(_get_key(pre + str(x), pre_node.data.get(x)))
            return words_list

        node = self.root
        for letter in prefix:
            node = node.data.get(letter)
            if not node:
                return []
        # `search` returns a tuple (always truthy), so the word check has to
        # look at the node itself.
        if node.is_word:
            return [prefix]
        return _get_key(prefix, node)
class TypeInfo:
    """Plain record describing one typed/linked item.

    All constructor arguments are stored unchanged under attributes of the
    same name.
    NOTE(review): ``pstart``/``pend`` look like the start and end positions
    of a span and ``weight`` like a score - confirm against the callers.
    """

    def __init__(self, label, index, linktype, value, orgvalue, pstart, pend,
                 weight):
        """Store every argument on the instance without validation."""
        # Identification of the item.
        self.label, self.index, self.linktype = label, index, linktype
        # The value and the original value it was derived from.
        self.value, self.orgvalue = value, orgvalue
        # Position information and weight.
        self.pstart, self.pend, self.weight = pstart, pend, weight
class Constant:
    """Bundle of fixed vocabularies and limits, exposed as attributes.

    Every instance gets its own fresh lists and dicts.
    NOTE(review): list positions and dict values are presumably used as
    label ids elsewhere, so the ordering should not be changed - confirm.
    """

    def __init__(self):
        """Populate all constant tables on the new instance."""
        # Names of the dialogue actions.
        self.action_ops = [
            'add_cond',
            'change_cond',
            'del_cond',
            'change_focus_total',
            'change_agg_only',
            'del_focus',
            'restart',
            'switch_table',
            'out_of_scripts',
            'repeat',
            'firstTurn',
        ]

        # Aggregation operators; the empty string stands for "none".
        self.agg_ops = [
            '',
            'AVG',
            'MAX',
            'MIN',
            'COUNT',
            'SUM',
            'COMPARE',
            'GROUP BY',
            'SAME',
        ]

        # Condition operators and the connectors between conditions.
        self.cond_ops = ['>', '<', '==', '!=', 'ASC', 'DESC']
        self.cond_conn_ops = ['', 'AND', 'OR']

        # Column type name -> integer id, numbered from 0.
        col_type_names = ('null', 'text', 'number', 'duration', 'bool',
                          'date')
        self.col_type_dict = {
            name: idx
            for idx, name in enumerate(col_type_names)
        }

        # Schema-link tag -> integer id, numbered from 1.
        link_tags = ('col_start', 'col_middle', 'col_end', 'val_start',
                     'val_middle', 'val_end')
        self.schema_link_dict = {
            tag: idx
            for idx, tag in enumerate(link_tags, start=1)
        }

        # Numeric limits (presumably on SELECT and WHERE items - confirm).
        self.max_select_num = 4
        self.max_where_num = 6
| @@ -0,0 +1,118 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| from typing import Any, Dict | |||
| import torch | |||
| from transformers import BertTokenizer | |||
| from modelscope.metainfo import Preprocessors | |||
| from modelscope.preprocessors.base import Preprocessor | |||
| from modelscope.preprocessors.builder import PREPROCESSORS | |||
| from modelscope.preprocessors.star3.fields.database import Database | |||
| from modelscope.preprocessors.star3.fields.schema_link import SchemaLinker | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import Fields, ModelFile | |||
| from modelscope.utils.type_assert import type_assert | |||
| __all__ = ['TableQuestionAnsweringPreprocessor'] | |||
@PREPROCESSORS.register_module(
    Fields.nlp,
    module_name=Preprocessors.table_question_answering_preprocessor)
class TableQuestionAnsweringPreprocessor(Preprocessor):
    """Turn a question into the record(s) consumed by the table QA model.

    The preprocessor tokenizes the question, links it against the tables
    of a ``Database`` with a ``SchemaLinker`` and packs the best linking
    result together with the table metadata.
    """

    def __init__(self, model_dir: str, db: Database = None, *args, **kwargs):
        """preprocess the data

        Args:
            model_dir (str): model path; the configuration file and the
                vocabulary file are read from it, as are ``table.json``
                and ``synonym.txt`` when ``db`` is not given.
            db (Database): database instance. When None, a database is
                built from the files in ``model_dir``.
            **kwargs: forwarded to the base class. ``device`` is also
                inspected here: ``'gpu'`` (or no ``device`` at all) selects
                CUDA when it is available, anything else selects the CPU.
        """
        super().__init__(*args, **kwargs)

        self.model_dir: str = model_dir
        self.config = Config.from_file(
            os.path.join(self.model_dir, ModelFile.CONFIGURATION))

        # read tokenizer
        self.tokenizer = BertTokenizer(
            os.path.join(self.model_dir, ModelFile.VOCAB_FILE))

        # read database
        if db is None:
            self.db = Database(
                tokenizer=self.tokenizer,
                table_file_path=os.path.join(self.model_dir, 'table.json'),
                syn_dict_file_path=os.path.join(self.model_dir,
                                                'synonym.txt'))
        else:
            self.db = db

        # get schema linker
        self.schema_linker = SchemaLinker()

        # set device: GPU is the default request, honoured only when CUDA
        # is actually available.
        wants_gpu = kwargs.get('device', 'gpu') == 'gpu'
        if wants_gpu and torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'

    def construct_data(self, search_result_list, nlu, nlu_t, db, history_sql):
        """Build one model input record per linking result.

        Args:
            search_result_list (list): linking results; each entry must
                provide ``table_id``, ``question_knowledge``,
                ``header_knowledge`` and ``schema_link``.
            nlu (str): the question text.
            nlu_t (list): the tokenized question.
            db (Database): database whose ``tables`` mapping supplies
                ``header_tok``, ``header_types`` and ``header_units``.
            history_sql: SQL of the previous turn, passed through as-is.

        Returns:
            list: one dict per entry of ``search_result_list``.
        """
        datas = []
        for search_result in search_result_list:
            table_id = search_result['table_id']
            table = db.tables[table_id]
            datas.append({
                'table_id': table_id,
                'question': nlu,
                'question_tok': nlu_t,
                'header_tok': table['header_tok'],
                'types': table['header_types'],
                'units': table['header_units'],
                'action': 0,
                'sql': None,
                'history_sql': history_sql,
                'wvi_corenlp': [],
                'bertindex_knowledge': search_result['question_knowledge'],
                'header_knowledge': search_result['header_knowledge'],
                'schema_link': search_result['schema_link'],
            })
        return datas

    @type_assert(object, dict)
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """process the raw input data

        Args:
            data (dict):
                question: the question to answer
                history_sql: predicted sql of the previous question;
                    optional, treated as None when the key is absent
                Example:
                    question: 'Which of these are hiring?'
                    history_sql: None

        Returns:
            Dict[str, Any]: ``{'datas': [...]}`` holding the records built
            by :meth:`construct_data`.
        """
        # tokenize question
        question = data['question']
        # A first-turn request has no history, so the key may be omitted.
        history_sql = data.get('history_sql')
        nlu = question.lower()
        nlu_t = self.tokenizer.tokenize(nlu)

        # get linking
        search_result_list = self.schema_linker.get_entity_linking(
            tokenizer=self.tokenizer,
            nlu=nlu,
            nlu_t=nlu_t,
            tables=self.db.tables,
            col_syn_dict=self.db.syn_dict)

        # collect data; only the first linking result is kept
        datas = self.construct_data(
            search_result_list=search_result_list[0:1],
            nlu=nlu,
            nlu_t=nlu_t,
            db=self.db,
            history_sql=history_sql)

        return {'datas': datas}
| @@ -3,7 +3,8 @@ from typing import List | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.nlp import (ConversationalTextToSqlPipeline, | |||
| DialogStateTrackingPipeline) | |||
| DialogStateTrackingPipeline, | |||
| TableQuestionAnsweringPipeline) | |||
| def text2sql_tracking_and_print_results( | |||
| @@ -42,3 +43,17 @@ def tracking_and_print_dialog_states( | |||
| print(json.dumps(result)) | |||
| history_states.extend([result[OutputKeys.OUTPUT], {}]) | |||
def tableqa_tracking_and_print_results(
        test_case, pipelines: List[TableQuestionAnsweringPipeline]):
    """Run a multi-turn test case through each pipeline and print results.

    Every pipeline starts with no history; after each question the
    ``'history'`` entry of the pipeline output is fed back as the
    ``'history_sql'`` of the next question.

    Args:
        test_case (dict): must provide ``'utterance'``, the ordered list of
            questions to ask.
        pipelines: callables taking a dict with ``'question'`` and
            ``'history_sql'`` and returning a dict with ``'output'`` (an
            object exposing ``string`` and ``query``) and ``'history'``.
    """
    for run_pipeline in pipelines:
        history = None
        for utterance in test_case['utterance']:
            request = {'question': utterance, 'history_sql': history}
            result = run_pipeline(request)
            answer = result['output']
            print('output_dict', answer.string, answer.query)
            history = result['history']
| @@ -0,0 +1,76 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import unittest | |||
| from typing import List | |||
| from transformers import BertTokenizer | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline | |||
| from modelscope.preprocessors import TableQuestionAnsweringPreprocessor | |||
| from modelscope.preprocessors.star3.fields.database import Database | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.nlp.nlp_utils import tableqa_tracking_and_print_results | |||
| from modelscope.utils.test_utils import test_level | |||
class TableQuestionAnswering(unittest.TestCase):
    """Run the table question answering pipeline on a multi-turn dialogue.

    Each test builds the pipeline in a different way and hands it to
    ``tableqa_tracking_and_print_results``; the tests contain no
    assertions of their own, so they only fail when an exception is raised.
    """

    model_id = 'damo/nlp_convai_text2sql_pretrain_cn'
    # Questions asked in order as one conversation.
    test_case = {
        'utterance':
        ['长江流域的小(2)型水库的库容总量是多少?', '那平均值是多少?', '那水库的名称呢?', '换成中型的呢?']
    }

    def setUp(self) -> None:
        self.task = Tasks.table_question_answering
        self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn'

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        """Build the pipeline from a locally downloaded snapshot."""
        cache_path = snapshot_download(self.model_id)
        preprocessor = TableQuestionAnsweringPreprocessor(model_dir=cache_path)
        pipeline_under_test = TableQuestionAnsweringPipeline(
            model=cache_path, preprocessor=preprocessor)
        tableqa_tracking_and_print_results(self.test_case,
                                           [pipeline_under_test])

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        """Build the pipeline from a ``Model`` loaded by model id."""
        model = Model.from_pretrained(self.model_id)
        preprocessor = TableQuestionAnsweringPreprocessor(
            model_dir=model.model_dir)
        pipeline_under_test = TableQuestionAnsweringPipeline(
            model=model, preprocessor=preprocessor)
        tableqa_tracking_and_print_results(self.test_case,
                                           [pipeline_under_test])

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_task(self):
        """Build the pipeline through the generic ``pipeline`` factory."""
        pipelines = [pipeline(Tasks.table_question_answering, self.model_id)]
        tableqa_tracking_and_print_results(self.test_case, pipelines)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_model_from_modelhub_with_other_classes(self):
        """Share one explicitly built ``Database`` between both stages."""
        model = Model.from_pretrained(self.model_id)
        model_dir = model.model_dir
        self.tokenizer = BertTokenizer(
            os.path.join(model_dir, ModelFile.VOCAB_FILE))
        db = Database(
            tokenizer=self.tokenizer,
            table_file_path=os.path.join(model_dir, 'table.json'),
            syn_dict_file_path=os.path.join(model_dir, 'synonym.txt'))
        preprocessor = TableQuestionAnsweringPreprocessor(
            model_dir=model_dir, db=db)
        pipeline_under_test = TableQuestionAnsweringPipeline(
            model=model, preprocessor=preprocessor, db=db)
        tableqa_tracking_and_print_results(self.test_case,
                                           [pipeline_under_test])
# Run the test cases of this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()