From 77cfcf0a9acbbfb5f122a65cb4ce235944596146 Mon Sep 17 00:00:00 2001 From: "caorongyu.cry" Date: Wed, 14 Sep 2022 19:04:56 +0800 Subject: [PATCH] [to #42322933] commit nlp_convai_text2sql_pretrain_cn inference process to modelscope commit nlp_convai_text2sql_pretrain_cn inference process to modelscope Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10025155 --- modelscope/metainfo.py | 3 + modelscope/models/nlp/__init__.py | 2 + modelscope/models/nlp/star3/__init__.py | 0 .../models/nlp/star3/configuration_star3.py | 128 +++ modelscope/models/nlp/star3/modeling_star3.py | 1023 +++++++++++++++++ .../models/nlp/table_question_answering.py | 747 ++++++++++++ modelscope/outputs.py | 8 + modelscope/pipelines/builder.py | 3 + modelscope/pipelines/nlp/__init__.py | 3 + .../nlp/table_question_answering_pipeline.py | 284 +++++ modelscope/preprocessors/__init__.py | 2 + modelscope/preprocessors/star3/__init__.py | 24 + .../preprocessors/star3/fields/__init__.py | 0 .../preprocessors/star3/fields/database.py | 77 ++ .../preprocessors/star3/fields/schema_link.py | 423 +++++++ .../preprocessors/star3/fields/struct.py | 181 +++ .../table_question_answering_preprocessor.py | 118 ++ modelscope/utils/nlp/nlp_utils.py | 17 +- .../test_table_question_answering.py | 76 ++ 19 files changed, 3118 insertions(+), 1 deletion(-) create mode 100644 modelscope/models/nlp/star3/__init__.py create mode 100644 modelscope/models/nlp/star3/configuration_star3.py create mode 100644 modelscope/models/nlp/star3/modeling_star3.py create mode 100644 modelscope/models/nlp/table_question_answering.py create mode 100644 modelscope/pipelines/nlp/table_question_answering_pipeline.py create mode 100644 modelscope/preprocessors/star3/__init__.py create mode 100644 modelscope/preprocessors/star3/fields/__init__.py create mode 100644 modelscope/preprocessors/star3/fields/database.py create mode 100644 modelscope/preprocessors/star3/fields/schema_link.py create mode 100644 modelscope/preprocessors/star3/fields/struct.py create mode 100644 modelscope/preprocessors/star3/table_question_answering_preprocessor.py create mode 100644 tests/pipelines/test_table_question_answering.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index e5c3873b..80a522b2 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -55,6 +55,7 @@ class Models(object): space_intent = 'space-intent' space_modeling = 'space-modeling' star = 'star' + star3 = 'star3' tcrf = 'transformer-crf' transformer_softmax = 'transformer-softmax' lcrf = 'lstm-crf' @@ -193,6 +194,7 @@ class Pipelines(object): plug_generation = 'plug-generation' faq_question_answering = 'faq-question-answering' conversational_text_to_sql = 'conversational-text-to-sql' + table_question_answering_pipeline = 'table-question-answering-pipeline' sentence_embedding = 'sentence-embedding' passage_ranking = 'passage-ranking' relation_extraction = 'relation-extraction' @@ -296,6 +298,7 @@ class Preprocessors(object): fill_mask_ponet = 'fill-mask-ponet' faq_question_answering_preprocessor = 'faq-question-answering-preprocessor' conversational_text_to_sql = 'conversational-text-to-sql' + table_question_answering_preprocessor = 'table-question-answering-preprocessor' re_tokenizer = 're-tokenizer' document_segmentation = 'document-segmentation' diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py index d411f1fb..443cb214 100644 --- a/modelscope/models/nlp/__init__.py +++ b/modelscope/models/nlp/__init__.py @@ -24,6 +24,7 @@ if TYPE_CHECKING: 
    from .space import SpaceForDialogIntent
    from .space import SpaceForDialogModeling
    from .space import SpaceForDialogStateTracking
+    from .table_question_answering import TableQuestionAnswering
    from .task_models import (InformationExtractionModel,
                              SequenceClassificationModel,
                              SingleBackboneTaskModelBase,
@@ -64,6 +65,7 @@ else:
            'SingleBackboneTaskModelBase', 'TokenClassificationModel'
        ],
        'token_classification': ['SbertForTokenClassification'],
+        'table_question_answering': ['TableQuestionAnswering'],
        'sentence_embedding': ['SentenceEmbedding'],
        'passage_ranking': ['PassageRanking'],
    }
diff --git a/modelscope/models/nlp/star3/__init__.py b/modelscope/models/nlp/star3/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/models/nlp/star3/configuration_star3.py b/modelscope/models/nlp/star3/configuration_star3.py
new file mode 100644
index 00000000..d49c70c9
--- /dev/null
+++ b/modelscope/models/nlp/star3/configuration_star3.py
@@ -0,0 +1,128 @@
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+# Copyright 2021-2022 The Alibaba DAMO Team Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BERT configuration."""
+
+from __future__ import absolute_import, division, print_function
+import copy
+import logging
+import math
+import os
+import shutil
+import tarfile
+import tempfile
+from pathlib import Path
+from typing import Union
+
+import json
+import numpy as np
+import torch
+import torch_scatter
+from icecream import ic
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+logger = logging.getLogger(__name__)
+
+
+class Star3Config(object):
+    """Configuration class to store the configuration of a `Star3Model`.
+    """
+
+    def __init__(self,
+                 vocab_size_or_config_json_file,
+                 hidden_size=768,
+                 num_hidden_layers=12,
+                 num_attention_heads=12,
+                 intermediate_size=3072,
+                 hidden_act='gelu',
+                 hidden_dropout_prob=0.1,
+                 attention_probs_dropout_prob=0.1,
+                 max_position_embeddings=512,
+                 type_vocab_size=2,
+                 initializer_range=0.02):
+        """Constructs Star3Config.
+
+        Args:
+            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `Star3Model`.
+            hidden_size: Size of the encoder layers and the pooler layer.
+            num_hidden_layers: Number of hidden layers in the Transformer encoder.
+            num_attention_heads: Number of attention heads for each attention layer in
+                the Transformer encoder.
+            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
+                layer in the Transformer encoder.
+            hidden_act: The non-linear activation function (function or string) in the
+                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
+            hidden_dropout_prob: The dropout probability for all fully connected
+                layers in the embeddings, encoder, and pooler.
+            attention_probs_dropout_prob: The dropout ratio for the attention
+                probabilities.
+            max_position_embeddings: The maximum sequence length that this model might
+                ever be used with. Typically set this to something large just in case
+                (e.g., 512 or 1024 or 2048).
+            type_vocab_size: The vocabulary size of the `token_type_ids` passed into `Star3Model`.
+            initializer_range: The stddev of the truncated_normal_initializer for
+                initializing all weight matrices.
+        """
+        if isinstance(vocab_size_or_config_json_file, str):
+            with open(
+                    vocab_size_or_config_json_file, 'r',
+                    encoding='utf-8') as reader:
+                json_config = json.loads(reader.read())
+            for key, value in json_config.items():
+                self.__dict__[key] = value
+        elif isinstance(vocab_size_or_config_json_file, int):
+            self.vocab_size = vocab_size_or_config_json_file
+            self.hidden_size = hidden_size
+            self.num_hidden_layers = num_hidden_layers
+            self.num_attention_heads = num_attention_heads
+            self.hidden_act = hidden_act
+            self.intermediate_size = intermediate_size
+            self.hidden_dropout_prob = hidden_dropout_prob
+            self.attention_probs_dropout_prob = attention_probs_dropout_prob
+            self.max_position_embeddings = max_position_embeddings
+            self.type_vocab_size = type_vocab_size
+            self.initializer_range = initializer_range
+        else:
+            raise ValueError(
+                'First argument must be either a vocabulary size (int) '
+                'or the path to a pretrained model config file (str)')
+
+    @classmethod
+    def from_dict(cls, json_object):
+        """Constructs a `Star3Config` from a Python dictionary of parameters."""
+        config = Star3Config(vocab_size_or_config_json_file=-1)
+        for key, value in json_object.items():
+            config.__dict__[key] = value
+        return config
+
+    @classmethod
+    def from_json_file(cls, json_file):
+        """Constructs a `Star3Config` from a json file of parameters."""
+        with open(json_file, 'r', encoding='utf-8') as reader:
+            text = reader.read()
+        return cls.from_dict(json.loads(text))
+
+    def __repr__(self):
+        return str(self.to_json_string())
+
+    def to_dict(self):
+        """Serializes this instance to a Python dictionary."""
+        output = copy.deepcopy(self.__dict__)
+        return output
+
+    def to_json_string(self):
+        """Serializes this instance to a JSON string."""
+        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n'
diff --git a/modelscope/models/nlp/star3/modeling_star3.py b/modelscope/models/nlp/star3/modeling_star3.py
new file mode 100644
index 00000000..ed5ea1b3
--- /dev/null
+++ b/modelscope/models/nlp/star3/modeling_star3.py
@@ -0,0 +1,1023 @@
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+# Copyright 2021-2022 The Alibaba DAMO Team Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BERT model.""" + +from __future__ import absolute_import, division, print_function +import copy +import logging +import math +import os +import shutil +import tarfile +import tempfile +from pathlib import Path +from typing import Union + +import json +import numpy as np +import torch +import torch_scatter +from torch import nn +from torch.nn import CrossEntropyLoss + +from modelscope.models.nlp.star3.configuration_star3 import Star3Config +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger + +logger = get_logger() + +CONFIG_NAME = ModelFile.CONFIGURATION +WEIGHTS_NAME = ModelFile.TORCH_MODEL_BIN_FILE + + +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {'gelu': gelu, 'relu': torch.nn.functional.relu, 'swish': swish} + + +class BertLayerNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). + """ + super(BertLayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + """ + + def __init__(self, config): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, + config.hidden_size) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + self.match_type_embeddings = nn.Embedding(11, config.hidden_size) + self.type_embeddings = nn.Embedding(6, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, + input_ids, + header_ids, + token_type_ids=None, + match_type_ids=None, + l_hs=None, + header_len=None, + type_idx=None, + col_dict_list=None, + ids=None, + header_flatten_tokens=None, + header_flatten_index=None, + header_flatten_output=None, + token_column_id=None, + token_column_mask=None, + column_start_index=None, + headers_length=None): + seq_length = input_ids.size(1) + position_ids = torch.arange( + seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + words_embeddings = self.word_embeddings(input_ids) + header_embeddings = self.word_embeddings(header_ids) + + # header mean pooling + header_flatten_embeddings = self.word_embeddings(header_flatten_tokens) + header_flatten_index = header_flatten_index.reshape( + (-1, header_flatten_index.shape[1], 1)) + header_flatten_index = header_flatten_index.repeat( + 1, 1, header_flatten_embeddings.shape[2]) + 
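+        # Mean-pool the flattened header-token embeddings back into one vector
+        # per column: scatter_mean below groups token embeddings by
+        # header_flatten_index, so each column is represented by the average of
+        # its header tokens.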
header_flatten_output = header_flatten_output.reshape( + (-1, header_flatten_output.shape[1], 1)) + header_flatten_output = header_flatten_output.repeat( + 1, 1, header_flatten_embeddings.shape[2]) + header_embeddings = torch_scatter.scatter_mean( + header_flatten_embeddings, + header_flatten_index, + out=header_flatten_output, + dim=1) + token_column_id = token_column_id.reshape( + (-1, token_column_id.shape[1], 1)) + token_column_id = token_column_id.repeat( + (1, 1, header_embeddings.shape[2])) + token_column_mask = token_column_mask.reshape( + (-1, token_column_mask.shape[1], 1)) + token_column_mask = token_column_mask.repeat( + (1, 1, header_embeddings.shape[2])) + token_header_embeddings = torch.gather(header_embeddings, 1, + token_column_id) + words_embeddings = words_embeddings * (1.0 - token_column_mask) + \ + token_header_embeddings * token_column_mask + + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = words_embeddings + position_embeddings + token_type_embeddings + + if match_type_ids is not None: + match_type_embeddings = self.match_type_embeddings(match_type_ids) + embeddings += match_type_embeddings + + if type_idx is not None: + type_embeddings = self.type_embeddings(type_idx) + embeddings += type_embeddings + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + + return embeddings + + +class BertSelfAttention(nn.Module): + + def __init__(self, config): + super(BertSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size + / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask, schema_link_matrix=None): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertSelfAttentionWithRelationsRAT(nn.Module): + ''' + Adapted from https://github.com/microsoft/rat-sql/blob/master/ratsql/models/transformer.py + ''' + + def __init__(self, config): + super(BertSelfAttentionWithRelationsRAT, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size + / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + self.relation_k_emb = nn.Embedding( + 7, config.hidden_size // config.num_attention_heads) + self.relation_v_emb = nn.Embedding( + 7, config.hidden_size // config.num_attention_heads) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask, relation): + ''' + relation is [batch, seq len, seq len] + ''' + mixed_query_layer = self.query( + hidden_states) # [batch, seq len, hidden dim] + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + relation_k = self.relation_k_emb( + relation) # [batch, seq len, seq len, head dim] + relation_v = self.relation_v_emb( + relation) # [batch, seq len, seq len, head dim] + + query_layer = self.transpose_for_scores( + mixed_query_layer) # [batch, num attn heads, seq len, head dim] + key_layer = self.transpose_for_scores( + mixed_key_layer) # [batch, num attn heads, seq len, head dim] + value_layer = self.transpose_for_scores( + mixed_value_layer) # [batch, num attn heads, seq len, head dim] + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose( + -1, -2)) # [batch, num attn heads, seq len, seq len] + + # relation_k_t is [batch, seq len, head dim, seq len] + relation_k_t = relation_k.transpose(-2, -1) + # query_layer_t is [batch, seq len, num attn heads, head dim] + query_layer_t = query_layer.permute(0, 2, 1, 3) + # relation_attention_scores is [batch, seq len, num attn heads, seq len] + relation_attention_scores = torch.matmul(query_layer_t, relation_k_t) + # relation_attention_scores_t is [batch, num attn heads, seq len, seq len] + relation_attention_scores_t = relation_attention_scores.permute( + 0, 2, 1, 3) + + merged_attention_scores = (attention_scores + + relation_attention_scores_t) / math.sqrt( + self.attention_head_size) + + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + merged_attention_scores = merged_attention_scores + attention_mask + + # Normalize the attention scores to probabilities. 
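+        # (Relation-aware self-attention as in RAT-SQL: the merged scores add a
+        # query-relation term to the usual query-key dot product, so schema-link
+        # relations between token pairs bias where each token attends.)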
+ attention_probs = nn.Softmax(dim=-1)(merged_attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + # attention_probs is [batch, num attn heads, seq len, seq len] + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + + # attention_probs_t is [batch, seq len, num attn heads, seq len] + attention_probs_t = attention_probs.permute(0, 2, 1, 3) + + # [batch, seq len, num attn heads, seq len] + # * [batch, seq len, seq len, head dim] + # = [batch, seq len, num attn heads, head dim] + context_relation = torch.matmul(attention_probs_t, relation_v) + + # context_relation_t is [batch, num attn heads, seq len, head dim] + context_relation_t = context_relation.permute(0, 2, 1, 3) + + merged_context_layer = context_layer + context_relation_t + merged_context_layer = merged_context_layer.permute(0, 2, 1, + 3).contiguous() + new_context_layer_shape = merged_context_layer.size()[:-2] + ( + self.all_head_size, ) + merged_context_layer = merged_context_layer.view( + *new_context_layer_shape) + return merged_context_layer + + +class BertSelfAttentionWithRelationsTableformer(nn.Module): + + def __init__(self, config): + super(BertSelfAttentionWithRelationsTableformer, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size + / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.schema_link_embeddings = nn.Embedding(7, self.num_attention_heads) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask, relation): + ''' + relation is [batch, seq len, seq len] + ''' + mixed_query_layer = self.query( + hidden_states) # [batch, seq len, hidden dim] + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + schema_link_embeddings = self.schema_link_embeddings( + relation) # [batch, seq len, seq len, 1] + schema_link_embeddings = schema_link_embeddings.permute(0, 3, 1, 2) + + query_layer = self.transpose_for_scores( + mixed_query_layer) # [batch, num attn heads, seq len, head dim] + key_layer = self.transpose_for_scores( + mixed_key_layer) # [batch, num attn heads, seq len, head dim] + value_layer = self.transpose_for_scores( + mixed_value_layer) # [batch, num attn heads, seq len, head dim] + + # Take the dot product between "query" and "key" to get the raw attention scores. 
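+        # (Tableformer-style linking: each schema-link relation contributes one
+        # learned scalar bias per attention head, added to the raw scores below,
+        # rather than being embedded into keys/values as in the RAT variant above.)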
+ attention_scores = torch.matmul(query_layer, key_layer.transpose( + -1, -2)) # [batch, num attn heads, seq len, seq len] + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + + merged_attention_scores = attention_scores + schema_link_embeddings + + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + merged_attention_scores = merged_attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(merged_attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + # attention_probs is [batch, num attn heads, seq len, seq len] + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertSelfOutput(nn.Module): + + def __init__(self, config): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config, schema_link_module='none'): + super(BertAttention, self).__init__() + if schema_link_module == 'none': + self.self = BertSelfAttention(config) + if schema_link_module == 'rat': + self.self = BertSelfAttentionWithRelationsRAT(config) + if schema_link_module == 'add': + self.self = BertSelfAttentionWithRelationsTableformer(config) + self.output = BertSelfOutput(config) + + def forward(self, input_tensor, attention_mask, schema_link_matrix=None): + self_output = self.self(input_tensor, attention_mask, + schema_link_matrix) + attention_output = self.output(self_output, input_tensor) + return attention_output + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super(BertIntermediate, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.intermediate_act_fn = ACT2FN[config.hidden_act] \ + if isinstance(config.hidden_act, str) else config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config): + super(BertOutput, self).__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, config, schema_link_module='none'): + super(BertLayer, self).__init__() + self.attention = BertAttention( + config, schema_link_module=schema_link_module) + self.intermediate = 
BertIntermediate(config) + self.output = BertOutput(config) + + def forward(self, hidden_states, attention_mask, schema_link_matrix=None): + attention_output = self.attention(hidden_states, attention_mask, + schema_link_matrix) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class SqlBertEncoder(nn.Module): + + def __init__(self, layers, config): + super(SqlBertEncoder, self).__init__() + layer = BertLayer(config) + self.layer = nn.ModuleList( + [copy.deepcopy(layer) for _ in range(layers)]) + + def forward(self, + hidden_states, + attention_mask, + output_all_encoded_layers=True): + all_encoder_layers = [] + for layer_module in self.layer: + hidden_states = layer_module(hidden_states, attention_mask) + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + if not output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +class BertEncoder(nn.Module): + + def __init__(self, config, schema_link_module='none'): + super(BertEncoder, self).__init__() + layer = BertLayer(config, schema_link_module=schema_link_module) + self.layer = nn.ModuleList( + [copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) + + def forward(self, + hidden_states, + attention_mask, + all_schema_link_matrix=None, + all_schema_link_mask=None, + output_all_encoded_layers=True): + all_encoder_layers = [] + for layer_module in self.layer: + hidden_states = layer_module(hidden_states, attention_mask, + all_schema_link_matrix) + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + if not output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +class BertPooler(nn.Module): + + def __init__(self, config): + super(BertPooler, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super(BertPredictionHeadTransform, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.transform_act_fn = ACT2FN[config.hidden_act] \ + if isinstance(config.hidden_act, str) else config.hidden_act + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + + def __init__(self, config, bert_model_embedding_weights): + super(BertLMPredictionHead, self).__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
+        self.decoder = nn.Linear(
+            bert_model_embedding_weights.size(1),
+            bert_model_embedding_weights.size(0),
+            bias=False)
+        self.decoder.weight = bert_model_embedding_weights
+        self.bias = nn.Parameter(
+            torch.zeros(bert_model_embedding_weights.size(0)))
+
+    def forward(self, hidden_states):
+        hidden_states = self.transform(hidden_states)
+        hidden_states = self.decoder(hidden_states) + self.bias
+        return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+
+    def __init__(self, config, bert_model_embedding_weights):
+        super(BertOnlyMLMHead, self).__init__()
+        self.predictions = BertLMPredictionHead(config,
+                                                bert_model_embedding_weights)
+
+    def forward(self, sequence_output):
+        prediction_scores = self.predictions(sequence_output)
+        return prediction_scores
+
+
+class BertOnlyNSPHead(nn.Module):
+
+    def __init__(self, config):
+        super(BertOnlyNSPHead, self).__init__()
+        self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+    def forward(self, pooled_output):
+        seq_relationship_score = self.seq_relationship(pooled_output)
+        return seq_relationship_score
+
+
+class BertPreTrainingHeads(nn.Module):
+
+    def __init__(self, config, bert_model_embedding_weights):
+        super(BertPreTrainingHeads, self).__init__()
+        self.predictions = BertLMPredictionHead(config,
+                                                bert_model_embedding_weights)
+        self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+    def forward(self, sequence_output, pooled_output):
+        prediction_scores = self.predictions(sequence_output)
+        seq_relationship_score = self.seq_relationship(pooled_output)
+        return prediction_scores, seq_relationship_score
+
+
+class PreTrainedBertModel(nn.Module):
+    """ An abstract class to handle weights initialization and
+        a simple interface for downloading and loading pretrained models.
+    """
+
+    def __init__(self, config, *inputs, **kwargs):
+        super(PreTrainedBertModel, self).__init__()
+        if not isinstance(config, Star3Config):
+            raise ValueError(
+                'Parameter config in `{}(config)` should be an instance of class `Star3Config`. '
+                'To create a model from a Google pretrained model use '
+                '`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`'.format(
+                    self.__class__.__name__, self.__class__.__name__))
+        self.config = config
+
+    def init_bert_weights(self, module):
+        """ Initialize the weights.
+        """
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(
+                mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, BertLayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        if isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+    @classmethod
+    def from_pretrained(cls,
+                        pretrained_model_name,
+                        state_dict=None,
+                        cache_dir=None,
+                        *inputs,
+                        **kwargs):
+        """
+        Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict.
+        Download and cache the pre-trained model file if needed.
+
+        Params:
+            pretrained_model_name: either:
+                - a str with the name of a pre-trained model to load selected in the list of:
+                    . `bert-base-uncased`
+                    . `bert-large-uncased`
+                    . `bert-base-cased`
+                    . `bert-large-cased`
+                    . `bert-base-multilingual-uncased`
+                    . `bert-base-multilingual-cased`
+                    . `bert-base-chinese`
+                - a path or url to a pretrained model archive containing:
+                    . `bert_config.json` a configuration file for the model
+                    . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
+            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
+            state_dict: an optional state dictionary (collections.OrderedDict object)
+                to use instead of Google pre-trained models
+            *inputs, **kwargs: additional input for the specific Bert class
+                (ex: num_labels for BertForSequenceClassification)
+        """
+        resolved_archive_file = pretrained_model_name
+        # redirect to the cache, if necessary
+        tempdir = None
+        if os.path.isdir(resolved_archive_file):
+            serialization_dir = resolved_archive_file
+        else:
+            # Extract archive to temp dir
+            tempdir = tempfile.mkdtemp()
+            logger.info('extracting archive file {} to temp dir {}'.format(
+                resolved_archive_file, tempdir))
+            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
+                archive.extractall(tempdir)
+            serialization_dir = tempdir
+        # Load config
+        config_file = os.path.join(serialization_dir, CONFIG_NAME)
+        config = Star3Config.from_json_file(config_file)
+        logger.info('Model config {}'.format(config))
+        # Instantiate model.
+        model = cls(config, *inputs, **kwargs)
+        if state_dict is None:
+            weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
+            state_dict = torch.load(weights_path)
+
+        old_keys = []
+        new_keys = []
+        for key in state_dict.keys():
+            new_key = None
+            if 'gamma' in key:
+                new_key = key.replace('gamma', 'weight')
+            if 'beta' in key:
+                new_key = key.replace('beta', 'bias')
+            if new_key:
+                old_keys.append(key)
+                new_keys.append(new_key)
+        for old_key, new_key in zip(old_keys, new_keys):
+            state_dict[new_key] = state_dict.pop(old_key)
+
+        missing_keys = []
+        unexpected_keys = []
+        error_msgs = []
+        # copy state_dict so _load_from_state_dict can modify it
+        metadata = getattr(state_dict, '_metadata', None)
+        state_dict = state_dict.copy()
+        if metadata is not None:
+            state_dict._metadata = metadata
+
+        def load(module, prefix=''):
+            local_metadata = {} if metadata is None else metadata.get(
+                prefix[:-1], {})
+            module._load_from_state_dict(state_dict, prefix, local_metadata,
+                                         True, missing_keys, unexpected_keys,
+                                         error_msgs)
+            for name, child in module._modules.items():
+                if child is not None:
+                    load(child, prefix + name + '.')
+
+        load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
+        if len(missing_keys) > 0:
+            logger.info(
+                'Weights of {} not initialized from pretrained model: {}'.
+                format(model.__class__.__name__, missing_keys))
+            print()
+            print('*' * 10, 'WARNING missing weights', '*' * 10)
+            print('Weights of {} not initialized from pretrained model: {}'.
+                  format(model.__class__.__name__, missing_keys))
+            print()
+        if len(unexpected_keys) > 0:
+            logger.info(
+                'Weights from pretrained model not used in {}: {}'.format(
+                    model.__class__.__name__, unexpected_keys))
+            print()
+            print('*' * 10, 'WARNING unexpected weights', '*' * 10)
+            print('Weights from pretrained model not used in {}: {}'.format(
+                model.__class__.__name__, unexpected_keys))
+            print()
+        if tempdir:
+            # Clean up temp dir
+            shutil.rmtree(tempdir)
+        return model
+
+
+class Star3Model(PreTrainedBertModel):
+    """Star3Model ("Bidirectional Embedding Representations from a Transformer pretrained on STAR3.0").
+
+    Params:
+        config: a Star3Config class instance with the configuration to build a new model
+
+    Inputs:
+        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
+            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
+            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
+        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
+            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
+            a `sentence B` token (see BERT paper for more details).
+        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
+            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
+            input sequence length in the current batch. It's the mask that we typically use for attention when
+            a batch has varying length sentences.
+        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output
+            as described below. Default: `True`.
+
+    Outputs: Tuple of (encoded_layers, pooled_output)
+        `encoded_layers`: controlled by `output_all_encoded_layers` argument:
+            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
+                of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
+                encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
+            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
+                to the last attention block of shape [batch_size, sequence_length, hidden_size],
+        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
+            classifier pretrained on top of the hidden state associated to the first token of the
+            input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper).
+
+    Example usage:
+    ```python
+    # Already been converted into WordPiece token ids
+    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+
+    config = modeling.Star3Config(vocab_size_or_config_json_file=32000, hidden_size=768,
+        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
+
+    model = modeling.Star3Model(config=config)
+    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
+    ```
+    """
+
+    def __init__(self, config, schema_link_module='none'):
+        super(Star3Model, self).__init__(config)
+        self.embeddings = BertEmbeddings(config)
+        self.encoder = BertEncoder(
+            config, schema_link_module=schema_link_module)
+        self.pooler = BertPooler(config)
+        self.apply(self.init_bert_weights)
+
+    def forward(self,
+                input_ids,
+                header_ids,
+                token_order_ids=None,
+                token_type_ids=None,
+                attention_mask=None,
+                match_type_ids=None,
+                l_hs=None,
+                header_len=None,
+                type_ids=None,
+                col_dict_list=None,
+                ids=None,
+                header_flatten_tokens=None,
+                header_flatten_index=None,
+                header_flatten_output=None,
+                token_column_id=None,
+                token_column_mask=None,
+                column_start_index=None,
+                headers_length=None,
+                all_schema_link_matrix=None,
+                all_schema_link_mask=None,
+                output_all_encoded_layers=True):
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros_like(input_ids)
+
+        # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + + # Bowen: comment out the following line for Pytorch >= 1.5 + # https://github.com/huggingface/transformers/issues/3936#issuecomment-793764416 + # extended_attention_mask = extended_attention_mask.to(self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = self.embeddings( + input_ids, header_ids, token_type_ids, match_type_ids, l_hs, + header_len, type_ids, col_dict_list, ids, header_flatten_tokens, + header_flatten_index, header_flatten_output, token_column_id, + token_column_mask, column_start_index, headers_length) + encoded_layers = self.encoder( + embedding_output, + extended_attention_mask, + all_schema_link_matrix=all_schema_link_matrix, + all_schema_link_mask=all_schema_link_mask, + output_all_encoded_layers=output_all_encoded_layers) + sequence_output = encoded_layers[-1] + pooled_output = self.pooler(sequence_output) + if not output_all_encoded_layers: + encoded_layers = encoded_layers[-1] + return encoded_layers, pooled_output + + +class Seq2SQL(nn.Module): + + def __init__(self, iS, hS, lS, dr, n_cond_ops, n_agg_ops, n_action_ops, + max_select_num, max_where_num, device): + super(Seq2SQL, self).__init__() + self.iS = iS + self.hS = hS + self.ls = lS + self.dr = dr + self.device = device + + self.n_agg_ops = n_agg_ops + self.n_cond_ops = n_cond_ops + self.n_action_ops = n_action_ops + self.max_select_num = max_select_num + self.max_where_num = max_where_num + + self.w_sss_model = nn.Linear(iS, max_where_num) + self.w_sse_model = nn.Linear(iS, max_where_num) + self.s_ht_model = nn.Linear(iS, max_select_num) + self.wc_ht_model = nn.Linear(iS, max_where_num) + + self.select_agg_model = nn.Linear(iS * max_select_num, + n_agg_ops * max_select_num) + self.w_op_model = nn.Linear(iS * max_where_num, + n_cond_ops * max_where_num) + + self.conn_model = nn.Linear(iS, 3) + self.action_model = nn.Linear(iS, n_action_ops + 1) + self.slen_model = nn.Linear(iS, max_select_num + 1) + self.wlen_model = nn.Linear(iS, max_where_num + 1) + + def forward(self, wemb_layer, l_n, l_hs, start_index, column_index, tokens, + ids): + # chunk input lists for multi-gpu + max_l_n = max(l_n) + max_l_hs = max(l_hs) + l_n = np.array(l_n)[ids.cpu().numpy()].tolist() + l_hs = np.array(l_hs)[ids.cpu().numpy()].tolist() + start_index = np.array(start_index)[ids.cpu().numpy()].tolist() + column_index = np.array(column_index)[ids.cpu().numpy()].tolist() + # tokens = np.array(tokens)[ids.cpu().numpy()].tolist() + + conn_index = [] + slen_index = [] + wlen_index = [] + action_index = [] + where_op_index = [] + select_agg_index = [] + header_pos_index = [] + query_index = [] + for ib, elem in enumerate(start_index): + # [SEP] conn [SEP] wlen [SEP] (wop [SEP])*wn slen [SEP] (agg [SEP])*sn + action_index.append(elem + 1) + conn_index.append(elem + 2) + 
wlen_index.append(elem + 3) + woi = [elem + 4 + i for i in range(self.max_where_num)] + + slen_index.append(elem + 4 + self.max_where_num) + sai = [ + elem + 5 + self.max_where_num + i + for i in range(self.max_select_num) + ] + where_op_index.append(woi) + select_agg_index.append(sai) + + qilist = [i for i in range(l_n[ib] + 2)] + [l_n[ib] + 1] * ( + max_l_n - l_n[ib]) + query_index.append(qilist) + + index = [column_index[ib] + i for i in range(0, l_hs[ib], 1)] + index += [index[0] for _ in range(max_l_hs - len(index))] + header_pos_index.append(index) + + # print("tokens: ", tokens) + # print("conn_index: ", conn_index, "start_index: ", start_index) + conn_index = torch.tensor(conn_index, dtype=torch.long).to(self.device) + slen_index = torch.tensor(slen_index, dtype=torch.long).to(self.device) + wlen_index = torch.tensor(wlen_index, dtype=torch.long).to(self.device) + action_index = torch.tensor( + action_index, dtype=torch.long).to(self.device) + where_op_index = torch.tensor( + where_op_index, dtype=torch.long).to(self.device) + select_agg_index = torch.tensor( + select_agg_index, dtype=torch.long).to(self.device) + query_index = torch.tensor( + query_index, dtype=torch.long).to(self.device) + header_index = torch.tensor( + header_pos_index, dtype=torch.long).to(self.device) + + bS = len(l_n) + conn_emb = torch.zeros([bS, self.iS]).to(self.device) + slen_emb = torch.zeros([bS, self.iS]).to(self.device) + wlen_emb = torch.zeros([bS, self.iS]).to(self.device) + action_emb = torch.zeros([bS, self.iS]).to(self.device) + wo_emb = torch.zeros([bS, self.max_where_num, self.iS]).to(self.device) + sa_emb = torch.zeros([bS, self.max_select_num, + self.iS]).to(self.device) + qv_emb = torch.zeros([bS, max_l_n + 2, self.iS]).to(self.device) + ht_emb = torch.zeros([bS, max_l_hs, self.iS]).to(self.device) + for i in range(bS): + conn_emb[i, :] = wemb_layer[i].index_select(0, conn_index[i]) + slen_emb[i, :] = wemb_layer[i].index_select(0, slen_index[i]) + wlen_emb[i, :] = wemb_layer[i].index_select(0, wlen_index[i]) + action_emb[i, :] = wemb_layer[i].index_select(0, action_index[i]) + + wo_emb[i, :, :] = wemb_layer[i].index_select( + 0, where_op_index[i, :]) + sa_emb[i, :, :] = wemb_layer[i].index_select( + 0, select_agg_index[i, :]) + qv_emb[i, :, :] = wemb_layer[i].index_select(0, query_index[i, :]) + ht_emb[i, :, :] = wemb_layer[i].index_select(0, header_index[i, :]) + + s_cco = self.conn_model(conn_emb.reshape(-1, self.iS)).reshape(bS, 3) + s_slen = self.slen_model(slen_emb.reshape(-1, self.iS)).reshape( + bS, self.max_select_num + 1) + s_wlen = self.wlen_model(wlen_emb.reshape(-1, self.iS)).reshape( + bS, self.max_where_num + 1) + s_action = self.action_model(action_emb.reshape(-1, self.iS)).reshape( + bS, self.n_action_ops + 1) + wo_output = self.w_op_model( + wo_emb.reshape(-1, self.iS * self.max_where_num)).reshape( + bS, -1, self.n_cond_ops) + + wc_output = self.wc_ht_model(ht_emb.reshape(-1, self.iS)).reshape( + bS, -1, self.max_where_num).transpose(1, 2) + + wv_ss = self.w_sss_model(qv_emb.reshape(-1, self.iS)).reshape( + bS, -1, self.max_where_num).transpose(1, 2) + wv_se = self.w_sse_model(qv_emb.reshape(-1, self.iS)).reshape( + bS, -1, self.max_where_num).transpose(1, 2) + + sc_output = self.s_ht_model(ht_emb.reshape(-1, self.iS)).reshape( + bS, -1, self.max_select_num).transpose(1, 2) + sa_output = self.select_agg_model( + sa_emb.reshape(-1, self.iS * self.max_select_num)).reshape( + bS, -1, self.n_agg_ops) + + return s_action, sc_output, sa_output, s_cco, wc_output, wo_output, ( + 
wv_ss, wv_se), (s_slen, s_wlen) diff --git a/modelscope/models/nlp/table_question_answering.py b/modelscope/models/nlp/table_question_answering.py new file mode 100644 index 00000000..19fdf178 --- /dev/null +++ b/modelscope/models/nlp/table_question_answering.py @@ -0,0 +1,747 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Dict, Optional + +import numpy +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import BertTokenizer + +from modelscope.metainfo import Models +from modelscope.models.base import Model, Tensor +from modelscope.models.builder import MODELS +from modelscope.models.nlp.star3.configuration_star3 import Star3Config +from modelscope.models.nlp.star3.modeling_star3 import Seq2SQL, Star3Model +from modelscope.preprocessors.star3.fields.struct import Constant +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.device import verify_device + +__all__ = ['TableQuestionAnswering'] + + +@MODELS.register_module( + Tasks.table_question_answering, module_name=Models.star3) +class TableQuestionAnswering(Model): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the table-question-answering model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + self.tokenizer = BertTokenizer( + os.path.join(model_dir, ModelFile.VOCAB_FILE)) + device_name = kwargs.get('device', 'gpu') + verify_device(device_name) + self._device_name = device_name + + state_dict = torch.load( + os.path.join(self.model_dir, ModelFile.TORCH_MODEL_BIN_FILE), + map_location='cpu') + + self.backbone_config = Star3Config.from_json_file( + os.path.join(self.model_dir, ModelFile.CONFIGURATION)) + self.backbone_model = Star3Model( + config=self.backbone_config, schema_link_module='rat') + self.backbone_model.load_state_dict(state_dict['backbone_model']) + + constant = Constant() + self.agg_ops = constant.agg_ops + self.cond_ops = constant.cond_ops + self.cond_conn_ops = constant.cond_conn_ops + self.action_ops = constant.action_ops + self.max_select_num = constant.max_select_num + self.max_where_num = constant.max_where_num + self.col_type_dict = constant.col_type_dict + self.schema_link_dict = constant.schema_link_dict + n_cond_ops = len(self.cond_ops) + n_agg_ops = len(self.agg_ops) + n_action_ops = len(self.action_ops) + iS = self.backbone_config.hidden_size + self.head_model = Seq2SQL(iS, 100, 2, 0.0, n_cond_ops, n_agg_ops, + n_action_ops, self.max_select_num, + self.max_where_num, self._device_name) + self.head_model.load_state_dict(state_dict['head_model'], strict=False) + + self.backbone_model.to(self._device_name) + self.head_model.to(self._device_name) + + def convert_string(self, pr_wvi, nlu, nlu_tt): + convs = [] + for b, nlu1 in enumerate(nlu): + conv_dict = {} + nlu_tt1 = nlu_tt[b] + idx = 0 + convflag = True + for i, ntok in enumerate(nlu_tt1): + if idx >= len(nlu1): + convflag = False + break + + if ntok.startswith('##'): + ntok = ntok.replace('##', '') + + tok = nlu1[idx:idx + 1].lower() + if ntok == tok: + conv_dict[i] = [idx, idx + 1] + idx += 1 + elif ntok == '#': + conv_dict[i] = [idx, idx] + elif ntok == '[UNK]': + conv_dict[i] = [idx, idx + 1] + j = i + 1 + idx += 1 + if idx < len(nlu1) and j < len( + nlu_tt1) and nlu_tt1[j] != '[UNK]': + while idx < len(nlu1): + val = nlu1[idx:idx + 1].lower() + if nlu_tt1[j].startswith(val): + break + idx += 1 + conv_dict[i][1] = 
idx + elif tok in ntok: + startid = idx + idx += 1 + while idx < len(nlu1): + tok += nlu1[idx:idx + 1].lower() + if ntok == tok: + conv_dict[i] = [startid, idx + 1] + break + idx += 1 + idx += 1 + else: + convflag = False + + conv = [] + if convflag: + for pr_wvi1 in pr_wvi[b]: + s1, e1 = conv_dict[pr_wvi1[0]] + s2, e2 = conv_dict[pr_wvi1[1]] + newidx = pr_wvi1[1] + while newidx + 1 < len( + nlu_tt1) and s2 == e2 and nlu_tt1[newidx] == '#': + newidx += 1 + s2, e2 = conv_dict[newidx] + if newidx + 1 < len(nlu_tt1) and nlu_tt1[ + newidx + 1].startswith('##'): + s2, e2 = conv_dict[newidx + 1] + phrase = nlu1[s1:e2] + conv.append(phrase) + else: + for pr_wvi1 in pr_wvi[b]: + phrase = ''.join(nlu_tt1[pr_wvi1[0]:pr_wvi1[1] + + 1]).replace('##', '') + conv.append(phrase) + convs.append(conv) + + return convs + + def get_fields_info(self, t1s, tables, train=True): + nlu, nlu_t, sql_i, q_know, t_know, action, hs_t, types, units, his_sql, schema_link = \ + [], [], [], [], [], [], [], [], [], [], [] + for t1 in t1s: + nlu.append(t1['question']) + nlu_t.append(t1['question_tok']) + hs_t.append(t1['header_tok']) + q_know.append(t1['bertindex_knowledge']) + t_know.append(t1['header_knowledge']) + types.append(t1['types']) + units.append(t1['units']) + his_sql.append(t1.get('history_sql', None)) + schema_link.append(t1.get('schema_link', [])) + if train: + action.append(t1.get('action', [0])) + sql_i.append(t1['sql']) + + return nlu, nlu_t, sql_i, q_know, t_know, action, hs_t, types, units, his_sql, schema_link + + def get_history_select_where(self, his_sql, header_len): + if his_sql is None: + return [0], [0] + + sel = [] + for seli in his_sql['sel']: + if seli + 1 < header_len and seli + 1 not in sel: + sel.append(seli + 1) + + whe = [] + for condi in his_sql['conds']: + if condi[0] + 1 < header_len and condi[0] + 1 not in whe: + whe.append(condi[0] + 1) + + if len(sel) == 0: + sel.append(0) + if len(whe) == 0: + whe.append(0) + + sel.sort() + whe.sort() + + return sel, whe + + def get_types_ids(self, col_type): + for key, type_ids in self.col_type_dict.items(): + if key in col_type.lower(): + return type_ids + return self.col_type_dict['null'] + + def generate_inputs(self, nlu1_tok, hs_t_1, type_t, unit_t, his_sql, + q_know, t_know, s_link): + tokens = [] + orders = [] + types = [] + segment_ids = [] + matchs = [] + col_dict = {} + schema_tok = [] + + tokens.append('[CLS]') + orders.append(0) + types.append(0) + i_st_nlu = len(tokens) + + matchs.append(0) + segment_ids.append(0) + for idx, token in enumerate(nlu1_tok): + if q_know[idx] == 100: + break + elif q_know[idx] >= 5: + matchs.append(1) + else: + matchs.append(q_know[idx] + 1) + tokens.append(token) + orders.append(0) + types.append(0) + segment_ids.append(0) + + i_ed_nlu = len(tokens) + + tokens.append('[SEP]') + orders.append(0) + types.append(0) + matchs.append(0) + segment_ids.append(0) + + sel, whe = self.get_history_select_where(his_sql, len(hs_t_1)) + + if len(sel) == 1 and sel[0] == 0 \ + and len(whe) == 1 and whe[0] == 0: + pass + else: + tokens.append('select') + orders.append(0) + types.append(0) + matchs.append(10) + segment_ids.append(0) + + for seli in sel: + tokens.append('[PAD]') + orders.append(0) + types.append(0) + matchs.append(10) + segment_ids.append(0) + col_dict[len(tokens) - 1] = seli + + tokens.append('where') + orders.append(0) + types.append(0) + matchs.append(10) + segment_ids.append(0) + + for whei in whe: + tokens.append('[PAD]') + orders.append(0) + types.append(0) + matchs.append(10) + segment_ids.append(0) + 
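+                # The [PAD] placeholders stand in for column headers; col_dict
+                # records placeholder position -> column index so that the pooled
+                # header embedding can be substituted there (see BertEmbeddings).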
+                col_dict[len(tokens) - 1] = whei
+
+            tokens.append('[SEP]')
+            orders.append(0)
+            types.append(0)
+            matchs.append(10)
+            segment_ids.append(0)
+
+        column_start = len(tokens)
+        i_hds_f = []
+        header_flatten_tokens, header_flatten_index = [], []
+        for i, hds11 in enumerate(hs_t_1):
+            if len(unit_t[i]) == 1 and unit_t[i][0] == 'null':
+                temp_header_tokens = hds11
+            else:
+                temp_header_tokens = hds11 + unit_t[i]
+            schema_tok.append(temp_header_tokens)
+            header_flatten_tokens.extend(temp_header_tokens)
+            header_flatten_index.extend([i + 1] * len(temp_header_tokens))
+            i_st_hd_f = len(tokens)
+            tokens += ['[PAD]']
+            orders.append(0)
+            types.append(self.get_types_ids(type_t[i]))
+            i_ed_hd_f = len(tokens)
+            col_dict[len(tokens) - 1] = i
+            i_hds_f.append((i_st_hd_f, i_ed_hd_f))
+            if i == 0:
+                matchs.append(6)
+            else:
+                matchs.append(t_know[i - 1] + 6)
+            segment_ids.append(1)
+
+        tokens.append('[SEP]')
+        orders.append(0)
+        types.append(0)
+        matchs.append(0)
+        segment_ids.append(1)
+
+        # decoder slot positions begin here:
+        # start_ids marks this final [SEP]
+        start_ids = len(tokens) - 1
+
+        tokens.append('action')  # action
+        orders.append(1)
+        types.append(0)
+        matchs.append(0)
+        segment_ids.append(1)
+
+        tokens.append('connect')  # column
+        orders.append(1)
+        types.append(0)
+        matchs.append(0)
+        segment_ids.append(1)
+
+        tokens.append('allen')  # select len
+        orders.append(1)
+        types.append(0)
+        matchs.append(0)
+        segment_ids.append(1)
+
+        for x in range(self.max_where_num):
+            tokens.append('act')  # op
+            orders.append(2 + x)
+            types.append(0)
+            matchs.append(0)
+            segment_ids.append(1)
+
+        tokens.append('size')  # where len
+        orders.append(1)
+        types.append(0)
+        matchs.append(0)
+        segment_ids.append(1)
+
+        for x in range(self.max_select_num):
+            tokens.append('focus')  # agg
+            orders.append(2 + x)
+            types.append(0)
+            matchs.append(0)
+            segment_ids.append(1)
+
+        i_nlu = (i_st_nlu, i_ed_nlu)
+
+        schema_link_matrix = numpy.zeros((len(tokens), len(tokens)),
+                                         dtype='int32')
+        schema_link_mask = numpy.zeros((len(tokens), len(tokens)),
+                                       dtype='float32')
+        for relation in s_link:
+            if relation['label'] in ['col', 'val']:
+                [q_st, q_ed] = relation['question_index']
+                cid = max(0, relation['column_index'])
+                schema_link_matrix[
+                    i_st_nlu + q_st: i_st_nlu + q_ed + 1,
+                    column_start + cid + 1: column_start + cid + 1 + 1] = \
+                    self.schema_link_dict[relation['label'] + '_middle']
+                schema_link_matrix[
+                    i_st_nlu + q_st,
+                    column_start + cid + 1: column_start + cid + 1 + 1] = \
+                    self.schema_link_dict[relation['label'] + '_start']
+                schema_link_matrix[
+                    i_st_nlu + q_ed,
+                    column_start + cid + 1: column_start + cid + 1 + 1] = \
+                    self.schema_link_dict[relation['label'] + '_end']
+                schema_link_mask[i_st_nlu + q_st:i_st_nlu + q_ed + 1,
+                                 column_start + cid + 1:column_start + cid + 1
+                                 + 1] = 1.0
+
+        return tokens, orders, types, segment_ids, matchs, \
+            i_nlu, i_hds_f, start_ids, column_start, col_dict, schema_tok, \
+            header_flatten_tokens, header_flatten_index, schema_link_matrix, schema_link_mask
+
+    def gen_l_hpu(self, i_hds):
+        """
+        Treat columns as if they were a batch of natural language utterances
+        with batch-size = # of columns * # of batch_size, e.g.
+        i_hds = [(17, 18), (19, 21), (22, 23), (24, 25), (26, 29), (30, 34)]
+        """
+        l_hpu = []
+        for i_hds1 in i_hds:
+            for i_hds11 in i_hds1:
+                l_hpu.append(i_hds11[1] - i_hds11[0])
+
+        return l_hpu
+
+    def get_bert_output(self, model_bert, tokenizer, nlu_t, hs_t, col_types,
+                        units, his_sql, q_know, t_know, schema_link):
+        """
+        Here, the input is tokenized further by the WordPiece (WP) tokenizer and fed into BERT.
+ + INPUT + :param model_bert: + :param tokenizer: WordPiece toknizer + :param nlu: Question + :param nlu_t: CoreNLP tokenized nlu. + :param hds: Headers + :param hs_t: None or 1st-level tokenized headers + :param max_seq_length: max input token length + + OUTPUT + tokens: BERT input tokens + nlu_tt: WP-tokenized input natural language questions + orig_to_tok_index: map the index of 1st-level-token to the index of 2nd-level-token + tok_to_orig_index: inverse map. + + """ + + l_n = [] + l_hs = [] # The length of columns for each batch + + input_ids = [] + order_ids = [] + type_ids = [] + segment_ids = [] + match_ids = [] + input_mask = [] + + i_nlu = [ + ] # index to retreive the position of contextual vector later. + i_hds = [] + tokens = [] + orders = [] + types = [] + matchs = [] + segments = [] + schema_link_matrix_list = [] + schema_link_mask_list = [] + start_index = [] + column_index = [] + col_dict_list = [] + header_list = [] + header_flatten_token_list = [] + header_flatten_tokenid_list = [] + header_flatten_index_list = [] + + header_tok_max_len = 0 + cur_max_length = 0 + + for b, nlu_t1 in enumerate(nlu_t): + hs_t1 = [hs_t[b][-1]] + hs_t[b][:-1] + type_t1 = [col_types[b][-1]] + col_types[b][:-1] + unit_t1 = [units[b][-1]] + units[b][:-1] + l_hs.append(len(hs_t1)) + + # [CLS] nlu [SEP] col1 [SEP] col2 [SEP] ...col-n [SEP] + # 2. Generate BERT inputs & indices. + tokens1, orders1, types1, segment1, match1, i_nlu1, i_hds_1, \ + start_idx, column_start, col_dict, schema_tok, \ + header_flatten_tokens, header_flatten_index, schema_link_matrix, schema_link_mask = \ + self.generate_inputs( + nlu_t1, hs_t1, type_t1, unit_t1, his_sql[b], + q_know[b], t_know[b], schema_link[b]) + + l_n.append(i_nlu1[1] - i_nlu1[0]) + start_index.append(start_idx) + column_index.append(column_start) + col_dict_list.append(col_dict) + tokens.append(tokens1) + orders.append(orders1) + types.append(types1) + segments.append(segment1) + matchs.append(match1) + i_nlu.append(i_nlu1) + i_hds.append(i_hds_1) + schema_link_matrix_list.append(schema_link_matrix) + schema_link_mask_list.append(schema_link_mask) + header_flatten_token_list.append(header_flatten_tokens) + header_flatten_index_list.append(header_flatten_index) + header_list.append(schema_tok) + header_max = max([len(schema_tok1) for schema_tok1 in schema_tok]) + if header_max > header_tok_max_len: + header_tok_max_len = header_max + + if len(tokens1) > cur_max_length: + cur_max_length = len(tokens1) + + if len(tokens1) > 512: + print('input too long!!! 
total_num:%d\t question:%s' % + (len(tokens1), ''.join(nlu_t1))) + + assert cur_max_length <= 512 + + for i, tokens1 in enumerate(tokens): + segment_ids1 = segments[i] + order_ids1 = orders[i] + type_ids1 = types[i] + match_ids1 = matchs[i] + input_ids1 = tokenizer.convert_tokens_to_ids(tokens1) + input_mask1 = [1] * len(input_ids1) + + while len(input_ids1) < cur_max_length: + input_ids1.append(0) + input_mask1.append(0) + segment_ids1.append(0) + order_ids1.append(0) + type_ids1.append(0) + match_ids1.append(0) + + if len(input_ids1) != cur_max_length: + print('Error: ', nlu_t1, tokens1, len(input_ids1), + cur_max_length) + + assert len(input_ids1) == cur_max_length + assert len(input_mask1) == cur_max_length + assert len(order_ids1) == cur_max_length + assert len(segment_ids1) == cur_max_length + assert len(match_ids1) == cur_max_length + assert len(type_ids1) == cur_max_length + + input_ids.append(input_ids1) + order_ids.append(order_ids1) + type_ids.append(type_ids1) + segment_ids.append(segment_ids1) + input_mask.append(input_mask1) + match_ids.append(match_ids1) + + header_len = [] + header_ids = [] + header_max_len = max( + [len(header_list1) for header_list1 in header_list]) + for header1 in header_list: + header_len1 = [] + header_ids1 = [] + for header_tok in header1: + header_len1.append(len(header_tok)) + header_tok_ids1 = tokenizer.convert_tokens_to_ids(header_tok) + while len(header_tok_ids1) < header_tok_max_len: + header_tok_ids1.append(0) + header_ids1.append(header_tok_ids1) + while len(header_ids1) < header_max_len: + header_ids1.append([0] * header_tok_max_len) + header_len.append(header_len1) + header_ids.append(header_ids1) + + for i, header_flatten_token in enumerate(header_flatten_token_list): + header_flatten_tokenid = tokenizer.convert_tokens_to_ids( + header_flatten_token) + header_flatten_tokenid_list.append(header_flatten_tokenid) + + # Convert to tensor + all_input_ids = torch.tensor( + input_ids, dtype=torch.long).to(self._device_name) + all_order_ids = torch.tensor( + order_ids, dtype=torch.long).to(self._device_name) + all_type_ids = torch.tensor( + type_ids, dtype=torch.long).to(self._device_name) + all_input_mask = torch.tensor( + input_mask, dtype=torch.long).to(self._device_name) + all_segment_ids = torch.tensor( + segment_ids, dtype=torch.long).to(self._device_name) + all_match_ids = torch.tensor( + match_ids, dtype=torch.long).to(self._device_name) + all_header_ids = torch.tensor( + header_ids, dtype=torch.long).to(self._device_name) + all_ids = torch.arange( + all_input_ids.shape[0], dtype=torch.long).to(self._device_name) + + bS = len(header_flatten_tokenid_list) + max_header_flatten_token_length = max( + [len(x) for x in header_flatten_tokenid_list]) + all_header_flatten_tokens = numpy.zeros( + (bS, max_header_flatten_token_length), dtype='int32') + all_header_flatten_index = numpy.zeros( + (bS, max_header_flatten_token_length), dtype='int32') + for i, header_flatten_tokenid in enumerate( + header_flatten_tokenid_list): + for j, tokenid in enumerate(header_flatten_tokenid): + all_header_flatten_tokens[i, j] = tokenid + for j, hdindex in enumerate(header_flatten_index_list[i]): + all_header_flatten_index[i, j] = hdindex + all_header_flatten_output = numpy.zeros((bS, header_max_len + 1), + dtype='int32') + all_header_flatten_tokens = torch.tensor( + all_header_flatten_tokens, dtype=torch.long).to(self._device_name) + all_header_flatten_index = torch.tensor( + all_header_flatten_index, dtype=torch.long).to(self._device_name) + 
all_header_flatten_output = torch.tensor( + all_header_flatten_output, + dtype=torch.float32).to(self._device_name) + + all_token_column_id = numpy.zeros((bS, cur_max_length), dtype='int32') + all_token_column_mask = numpy.zeros((bS, cur_max_length), + dtype='float32') + for bi, col_dict in enumerate(col_dict_list): + for ki, vi in col_dict.items(): + all_token_column_id[bi, ki] = vi + 1 + all_token_column_mask[bi, ki] = 1.0 + all_token_column_id = torch.tensor( + all_token_column_id, dtype=torch.long).to(self._device_name) + all_token_column_mask = torch.tensor( + all_token_column_mask, dtype=torch.float32).to(self._device_name) + + all_schema_link_matrix = numpy.zeros( + (bS, cur_max_length, cur_max_length), dtype='int32') + all_schema_link_mask = numpy.zeros( + (bS, cur_max_length, cur_max_length), dtype='float32') + for i, schema_link_matrix in enumerate(schema_link_matrix_list): + temp_len = schema_link_matrix.shape[0] + all_schema_link_matrix[i, 0:temp_len, + 0:temp_len] = schema_link_matrix + all_schema_link_mask[i, 0:temp_len, + 0:temp_len] = schema_link_mask_list[i] + all_schema_link_matrix = torch.tensor( + all_schema_link_matrix, dtype=torch.long).to(self._device_name) + all_schema_link_mask = torch.tensor( + all_schema_link_mask, dtype=torch.long).to(self._device_name) + + # 5. generate l_hpu from i_hds + l_hpu = self.gen_l_hpu(i_hds) + + # 4. Generate BERT output. + all_encoder_layer, pooled_output = model_bert( + all_input_ids, + all_header_ids, + token_order_ids=all_order_ids, + token_type_ids=all_segment_ids, + attention_mask=all_input_mask, + match_type_ids=all_match_ids, + l_hs=l_hs, + header_len=header_len, + type_ids=all_type_ids, + col_dict_list=col_dict_list, + ids=all_ids, + header_flatten_tokens=all_header_flatten_tokens, + header_flatten_index=all_header_flatten_index, + header_flatten_output=all_header_flatten_output, + token_column_id=all_token_column_id, + token_column_mask=all_token_column_mask, + column_start_index=column_index, + headers_length=l_hs, + all_schema_link_matrix=all_schema_link_matrix, + all_schema_link_mask=all_schema_link_mask, + output_all_encoded_layers=False) + + return all_encoder_layer, pooled_output, tokens, i_nlu, i_hds, \ + l_n, l_hpu, l_hs, start_index, column_index, all_ids + + def predict(self, querys): + self.head_model.eval() + self.backbone_model.eval() + + nlu, nlu_t, sql_i, q_know, t_know, tb, hs_t, types, units, his_sql, schema_link = \ + self.get_fields_info(querys, None, train=False) + + with torch.no_grad(): + all_encoder_layer, _, tokens, i_nlu, i_hds, l_n, l_hpu, l_hs, start_index, column_index, ids = \ + self.get_bert_output( + self.backbone_model, self.tokenizer, + nlu_t, hs_t, types, units, his_sql, q_know, t_know, schema_link) + + s_action, s_sc, s_sa, s_cco, s_wc, s_wo, s_wvs, s_len = self.head_model( + all_encoder_layer, l_n, l_hs, start_index, column_index, + tokens, ids) + + action_batch = torch.argmax(F.softmax(s_action, -1), -1).cpu().tolist() + scco_batch = torch.argmax(F.softmax(s_cco, -1), -1).cpu().tolist() + sc_batch = torch.argmax(F.softmax(s_sc, -1), -1).cpu().tolist() + sa_batch = torch.argmax(F.softmax(s_sa, -1), -1).cpu().tolist() + wc_batch = torch.argmax(F.softmax(s_wc, -1), -1).cpu().tolist() + wo_batch = torch.argmax(F.softmax(s_wo, -1), -1).cpu().tolist() + s_wvs_s, s_wvs_e = s_wvs + wvss_batch = torch.argmax(F.softmax(s_wvs_s, -1), -1).cpu().tolist() + wvse_batch = torch.argmax(F.softmax(s_wvs_e, -1), -1).cpu().tolist() + s_slen, s_wlen = s_len + slen_batch = torch.argmax(F.softmax(s_slen, -1), 
-1).cpu().tolist()
+            wlen_batch = torch.argmax(F.softmax(s_wlen, -1), -1).cpu().tolist()
+
+            pr_wvi = []
+            for i in range(len(querys)):
+                wvi = []
+                for j in range(wlen_batch[i]):
+                    wvi.append([
+                        max(0, wvss_batch[i][j] - 1),
+                        max(0, wvse_batch[i][j] - 1)
+                    ])
+                pr_wvi.append(wvi)
+            pr_wvi_str = self.convert_string(pr_wvi, nlu, nlu_t)
+
+            pre_results = []
+            for ib in range(len(querys)):
+                res_one = {}
+                sql = {}
+                sql['cond_conn_op'] = scco_batch[ib]
+                sl = slen_batch[ib]
+                sql['sel'] = list(
+                    numpy.array(sc_batch[ib][:sl]).astype(numpy.int32) - 1)
+                sql['agg'] = list(
+                    numpy.array(sa_batch[ib][:sl]).astype(numpy.int32))
+                sels = []
+                aggs = []
+                for ia, sel in enumerate(sql['sel']):
+                    if sel == -1:
+                        if sql['agg'][ia] > 0:
+                            sels.append(l_hs[ib] - 1)
+                            aggs.append(sql['agg'][ia])
+                        continue
+                    sels.append(sel)
+                    if sql['agg'][ia] == -1:
+                        aggs.append(0)
+                    else:
+                        aggs.append(sql['agg'][ia])
+                if len(sels) == 0:
+                    sels.append(l_hs[ib] - 1)
+                    aggs.append(0)
+                assert len(sels) == len(aggs)
+                sql['sel'] = sels
+                sql['agg'] = aggs
+
+                conds = []
+                wl = wlen_batch[ib]
+                wc_os = list(
+                    numpy.array(wc_batch[ib][:wl]).astype(numpy.int32) - 1)
+                wo_os = list(numpy.array(wo_batch[ib][:wl]).astype(numpy.int32))
+                res_one['question_tok'] = querys[ib]['question_tok']
+                for i in range(wl):
+                    if wc_os[i] == -1:
+                        continue
+                    conds.append([wc_os[i], wo_os[i], pr_wvi_str[ib][i]])
+                if len(conds) == 0:
+                    conds.append([l_hs[ib] - 1, 2, 'Null'])
+                sql['conds'] = conds
+                res_one['question'] = querys[ib]['question']
+                res_one['table_id'] = querys[ib]['table_id']
+                res_one['sql'] = sql
+                res_one['action'] = action_batch[ib]
+                res_one['model_out'] = [
+                    sc_batch[ib], sa_batch[ib], wc_batch[ib], wo_batch[ib],
+                    wvss_batch[ib], wvse_batch[ib]
+                ]
+                pre_results.append(res_one)
+
+        return pre_results
+
+    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        """run prediction with the preprocessed data and return the raw result
+
+        Args:
+            input (Dict[str, Tensor]): the preprocessed data
+
+        Returns:
+            Dict[str, Tensor]: the predicted result of the first query and
+            the history sql it was conditioned on
+        """
+        result = self.predict(input['datas'])[0]
+
+        return {
+            'result': result,
+            'history_sql': input['datas'][0]['history_sql']
+        }
diff --git a/modelscope/outputs.py b/modelscope/outputs.py
index 8ddeb314..d7d619bf 100644
--- a/modelscope/outputs.py
+++ b/modelscope/outputs.py
@@ -35,6 +35,7 @@ class OutputKeys(object):
     UUID = 'uuid'
     WORD = 'word'
     KWS_LIST = 'kws_list'
+    HISTORY = 'history'
     TIMESTAMPS = 'timestamps'
     SPLIT_VIDEO_NUM = 'split_video_num'
     SPLIT_META_DICT = 'split_meta_dict'
@@ -471,6 +472,13 @@ TASK_OUTPUTS = {
     # }
     Tasks.conversational_text_to_sql: [OutputKeys.TEXT],
 
+    # table-question-answering result for a single sample
+    # {
+    #     "output": "SELECT shop.Name FROM shop",
+    #     "history": {"sel": [0], "agg": [0], "conds": [[0, 0, 'val']]}
+    # }
+    Tasks.table_question_answering: [OutputKeys.OUTPUT, OutputKeys.HISTORY],
+
     # ============ audio tasks ===================
     # asr result for single sample
     # { "text": "每一天都要快乐喔"}
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 50313cf7..5e244b27 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -66,6 +66,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.conversational_text_to_sql:
     (Pipelines.conversational_text_to_sql,
      'damo/nlp_star_conversational-text-to-sql'),
+    Tasks.table_question_answering:
+    (Pipelines.table_question_answering_pipeline,
+     'damo/nlp_convai_text2sql_pretrain_cn'),
     Tasks.text_error_correction: (Pipelines.text_error_correction,
                                   'damo/nlp_bart_text-error-correction_chinese'),
 
diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py
index 6f898c0f..b5c53f82 100644
--- a/modelscope/pipelines/nlp/__init__.py
+++ b/modelscope/pipelines/nlp/__init__.py
@@ -5,6 +5,7 @@ from modelscope.utils.import_utils import LazyImportModule
 
 if TYPE_CHECKING:
     from .conversational_text_to_sql_pipeline import ConversationalTextToSqlPipeline
+    from .table_question_answering_pipeline import TableQuestionAnsweringPipeline
     from .dialog_intent_prediction_pipeline import DialogIntentPredictionPipeline
     from .dialog_modeling_pipeline import DialogModelingPipeline
     from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline
@@ -31,6 +32,8 @@ else:
     _import_structure = {
         'conversational_text_to_sql_pipeline':
         ['ConversationalTextToSqlPipeline'],
+        'table_question_answering_pipeline':
+        ['TableQuestionAnsweringPipeline'],
         'dialog_intent_prediction_pipeline':
         ['DialogIntentPredictionPipeline'],
         'dialog_modeling_pipeline': ['DialogModelingPipeline'],
diff --git a/modelscope/pipelines/nlp/table_question_answering_pipeline.py b/modelscope/pipelines/nlp/table_question_answering_pipeline.py
new file mode 100644
index 00000000..8235a4d6
--- /dev/null
+++ b/modelscope/pipelines/nlp/table_question_answering_pipeline.py
@@ -0,0 +1,284 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
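+# Pipeline flow: the star3 preprocessor links the question against the tables
+# in Database, the model predicts a sketch-style sql dict plus a dialogue
+# action, and postprocess() merges that prediction with history_sql before
+# rendering the final SQLQuery.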
+import copy
+import os
+from typing import Any, Dict, Union
+
+import torch
+from transformers import BertTokenizer
+
+from modelscope.metainfo import Pipelines
+from modelscope.models import Model
+from modelscope.models.nlp import TableQuestionAnswering
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import TableQuestionAnsweringPreprocessor
+from modelscope.preprocessors.star3.fields.database import Database
+from modelscope.preprocessors.star3.fields.struct import Constant, SQLQuery
+from modelscope.utils.constant import ModelFile, Tasks
+
+__all__ = ['TableQuestionAnsweringPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.table_question_answering,
+    module_name=Pipelines.table_question_answering_pipeline)
+class TableQuestionAnsweringPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[TableQuestionAnswering, str],
+                 preprocessor: TableQuestionAnsweringPreprocessor = None,
+                 db: Database = None,
+                 **kwargs):
+        """use `model` and `preprocessor` to create a table question answering prediction pipeline
+
+        Args:
+            model (TableQuestionAnswering): a model instance or a model id on the hub
+            preprocessor (TableQuestionAnsweringPreprocessor): a preprocessor instance
+            db (Database): a database instance holding the tables to query
+        """
+        model = model if isinstance(
+            model, TableQuestionAnswering) else Model.from_pretrained(model)
+        if preprocessor is None:
+            preprocessor = TableQuestionAnsweringPreprocessor(model.model_dir)
+
+        # initialize tokenizer
+        self.tokenizer = BertTokenizer(
+            os.path.join(model.model_dir, ModelFile.VOCAB_FILE))
+
+        # initialize database
+        if db is None:
+            self.db = Database(
+                tokenizer=self.tokenizer,
+                table_file_path=os.path.join(model.model_dir, 'table.json'),
+                syn_dict_file_path=os.path.join(model.model_dir,
+                                                'synonym.txt'))
+        else:
+            self.db = db
+
+        constant = Constant()
+        self.agg_ops = constant.agg_ops
+        self.cond_ops = constant.cond_ops
+        self.cond_conn_ops = constant.cond_conn_ops
+        self.action_ops = constant.action_ops
+        self.max_select_num = constant.max_select_num
+        self.max_where_num = constant.max_where_num
+        self.col_type_dict = constant.col_type_dict
+        self.schema_link_dict = constant.schema_link_dict
+
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+
+    def post_process_multi_turn(self, history_sql, result, table):
+        action = self.action_ops[result['action']]
+        headers = table['header_name']
+        current_sql = result['sql']
+
+        if history_sql is None:
+            return current_sql
+
+        if action == 'out_of_scripts':
+            return history_sql
+
+        elif action == 'switch_table':
+            return current_sql
+
+        elif action == 'restart':
+            return current_sql
+
+        elif action == 'firstTurn':
+            return current_sql
+
+        elif action == 'del_focus':
+            pre_final_sql = copy.deepcopy(history_sql)
+            pre_sels = []
+            pre_aggs = []
+            for idx, seli in enumerate(pre_final_sql['sel']):
+                if seli not in current_sql['sel']:
+                    pre_sels.append(seli)
+                    pre_aggs.append(pre_final_sql['agg'][idx])
+
+            if len(pre_sels) < 1:
+                pre_sels.append(len(headers))
+                pre_aggs.append(0)
+            pre_final_sql['sel'] = pre_sels
+            pre_final_sql['agg'] = pre_aggs
+
+            final_conds = []
+            for condi in pre_final_sql['conds']:
+                if condi[0] < len(headers):
+                    final_conds.append(condi)
+            if len(final_conds) < 1:
+                final_conds.append([len(headers), 2, 'Null'])
+            pre_final_sql['conds'] = final_conds
+
+            return pre_final_sql
+
+        elif action == 'change_agg_only':
+            pre_final_sql = history_sql
+            pre_sels = []
+            pre_aggs = []
+            for idx, seli in enumerate(pre_final_sql['sel']):
+                if seli in current_sql['sel']:
+                    pre_sels.append(seli)
+                    changed_aggi = -1
+                    for idx_single, aggi in enumerate(current_sql['agg']):
+                        if current_sql['sel'][idx_single] == seli:
+                            changed_aggi = aggi
+                    pre_aggs.append(changed_aggi)
+                else:
+                    pre_sels.append(seli)
+                    pre_aggs.append(pre_final_sql['agg'][idx])
+            pre_final_sql['sel'] = pre_sels
+            pre_final_sql['agg'] = pre_aggs
+
+            return pre_final_sql
+
+        elif action == 'change_focus_total':
+            pre_final_sql = history_sql
+            pre_sels = current_sql['sel']
+            pre_aggs = current_sql['agg']
+
+            pre_final_sql['sel'] = pre_sels
+            pre_final_sql['agg'] = pre_aggs
+            for pre_condi in current_sql['conds']:
+                if pre_condi[0] < len(headers):
+                    in_flag = False
+                    for history_condi in history_sql['conds']:
+                        if pre_condi[0] == history_condi[0]:
+                            in_flag = True
+                    if not in_flag:
+                        pre_final_sql['conds'].append(pre_condi)
+
+            return pre_final_sql
+
+        elif action == 'del_cond':
+            pre_final_sql = copy.deepcopy(history_sql)
+
+            final_conds = []
+
+            for idx, condi in enumerate(pre_final_sql['conds']):
+                if condi[0] not in current_sql['sel']:
+                    final_conds.append(condi)
+            pre_final_sql['conds'] = final_conds
+
+            final_conds = []
+            for condi in pre_final_sql['conds']:
+                if condi[0] < len(headers):
+                    final_conds.append(condi)
+            if len(final_conds) < 1:
+                final_conds.append([len(headers), 2, 'Null'])
+            pre_final_sql['conds'] = final_conds
+
+            return pre_final_sql
+
+        elif action == 'change_cond':
+            pre_final_sql = history_sql
+            final_conds = []
+
+            for idx, condi in enumerate(pre_final_sql['conds']):
+                in_single_flag = False
+                for single_condi in current_sql['conds']:
+                    if condi[0] == single_condi[0]:
+                        in_single_flag = True
+                        final_conds.append(single_condi)
+                if not in_single_flag:
+                    final_conds.append(condi)
+            pre_final_sql['conds'] = final_conds
+
+            final_conds = []
+            for condi in pre_final_sql['conds']:
+                if condi[0] < len(headers):
+                    final_conds.append(condi)
+            if len(final_conds) < 1:
+                final_conds.append([len(headers), 2, 'Null'])
+            pre_final_sql['conds'] = final_conds
+
+            return pre_final_sql
+
+        elif action == 'add_cond':
+            pre_final_sql = history_sql
+            final_conds = pre_final_sql['conds']
+            for idx, condi in enumerate(current_sql['conds']):
+                if condi[0] < len(headers):
+                    final_conds.append(condi)
+            pre_final_sql['conds'] = final_conds
+
+            final_conds = []
+            for condi in pre_final_sql['conds']:
+                if condi[0] < len(headers):
+                    final_conds.append(condi)
+            if len(final_conds) < 1:
+                final_conds.append([len(headers), 2, 'Null'])
+            pre_final_sql['conds'] = final_conds
+
+            return pre_final_sql
+
+        else:
+            return current_sql
+
+    def sql_dict_to_str(self, result, table):
+        """
+        convert the sql struct to a human-readable string and an executable query
+        """
+        header_names = table['header_name'] + ['空列']
+        header_ids = table['header_id'] + ['null']
+        sql = result['sql']
+
+        str_sel_list, sql_sel_list = [], []
+        for idx, sel in enumerate(sql['sel']):
+            header_name = header_names[sel]
+            header_id = '`%s`.`%s`' % (table['table_id'], header_ids[sel])
+            if sql['agg'][idx] == 0:
+                str_sel_list.append(header_name)
+                sql_sel_list.append(header_id)
+            else:
+                str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '( '
+                                    + header_name + ' )')
+                sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '( '
+                                    + header_id + ' )')
+
+        str_cond_list, sql_cond_list = [], []
+        for cond in sql['conds']:
+            header_name = header_names[cond[0]]
+            header_id = '`%s`.`%s`' % (table['table_id'], header_ids[cond[0]])
+            op = self.cond_ops[cond[1]]
+            value = cond[2]
+            str_cond_list.append('( ' + header_name + ' ' + op + ' "' + value
+                                 + '" )')
+            sql_cond_list.append('( ' + header_id + ' ' + op + ' "' + value
+                                 + '" )')
+
+        cond = ' ' + self.cond_conn_ops[sql['cond_conn_op']] + ' '
+
+        final_str = 'SELECT %s FROM %s WHERE %s' % (', '.join(str_sel_list),
+                                                    table['table_name'],
+                                                    cond.join(str_cond_list))
+        final_sql = 'SELECT %s FROM `%s` WHERE %s' % (', '.join(sql_sel_list),
+                                                      table['table_id'],
+                                                      cond.join(sql_cond_list))
+        sql = SQLQuery(
+            string=final_str, query=final_sql, sql_result=result['sql'])
+
+        return sql
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+        """process the prediction results
+
+        Args:
+            inputs (Dict[str, Any]): the model prediction and the history sql
+
+        Returns:
+            Dict[str, str]: the prediction results
+        """
+        result = inputs['result']
+        history_sql = inputs['history_sql']
+        result['sql'] = self.post_process_multi_turn(
+            history_sql=history_sql,
+            result=result,
+            table=self.db.tables[result['table_id']])
+        sql = self.sql_dict_to_str(
+            result=result, table=self.db.tables[result['table_id']])
+        output = {OutputKeys.OUTPUT: sql, OutputKeys.HISTORY: result['sql']}
+        return output
+
+    def _collate_fn(self, data):
+        return data
diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py
index 212339ae..04901dc5 100644
--- a/modelscope/preprocessors/__init__.py
+++ b/modelscope/preprocessors/__init__.py
@@ -30,6 +30,7 @@ if TYPE_CHECKING:
         DialogStateTrackingPreprocessor)
     from .video import ReadVideoData, MovieSceneSegmentationPreprocessor
     from .star import ConversationalTextToSqlPreprocessor
+    from .star3 import TableQuestionAnsweringPreprocessor
 
 else:
     _import_structure = {
@@ -62,6 +63,7 @@ else:
             'DialogStateTrackingPreprocessor', 'InputFeatures'
         ],
         'star': ['ConversationalTextToSqlPreprocessor'],
+        'star3': ['TableQuestionAnsweringPreprocessor'],
     }
 
     import sys
diff --git a/modelscope/preprocessors/star3/__init__.py b/modelscope/preprocessors/star3/__init__.py
new file mode 100644
index 00000000..9aa562d7
--- /dev/null
+++ b/modelscope/preprocessors/star3/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .table_question_answering_preprocessor import TableQuestionAnsweringPreprocessor
+
+else:
+    _import_structure = {
+        'table_question_answering_preprocessor':
+        ['TableQuestionAnsweringPreprocessor'],
+    }
+
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
diff --git a/modelscope/preprocessors/star3/fields/__init__.py b/modelscope/preprocessors/star3/fields/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/preprocessors/star3/fields/database.py b/modelscope/preprocessors/star3/fields/database.py
new file mode 100644
index 00000000..a99800cf
--- /dev/null
+++ b/modelscope/preprocessors/star3/fields/database.py
@@ -0,0 +1,77 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
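+# Database loads the tables from table.json, tokenizes headers and units,
+# builds a per-column value Trie for cell matching, and reads the column
+# synonym dict consumed by SchemaLinker.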
+import json
+import tqdm
+
+from modelscope.preprocessors.star3.fields.struct import Trie
+
+
+class Database:
+
+    def __init__(self, tokenizer, table_file_path, syn_dict_file_path):
+        self.tokenizer = tokenizer
+        self.tables = self.init_tables(table_file_path=table_file_path)
+        self.syn_dict = self.init_syn_dict(
+            syn_dict_file_path=syn_dict_file_path)
+
+    def init_tables(self, table_file_path):
+        tables = {}
+        lines = []
+        with open(table_file_path, 'r', encoding='utf-8') as fo:
+            for line in fo:
+                lines.append(line)
+
+        for line in tqdm.tqdm(lines, desc='Load Tables'):
+            table = json.loads(line.strip())
+
+            table_header_length = 0
+            headers_tokens = []
+            for header in table['header_name']:
+                header_tokens = self.tokenizer.tokenize(header)
+                table_header_length += len(header_tokens)
+                headers_tokens.append(header_tokens)
+            empty_column = self.tokenizer.tokenize('空列')
+            table_header_length += len(empty_column)
+            headers_tokens.append(empty_column)
+            table['tablelen'] = table_header_length
+            table['header_tok'] = headers_tokens
+
+            table['header_types'].append('null')
+            table['header_units'] = [
+                self.tokenizer.tokenize(unit) for unit in table['header_units']
+            ] + [[]]
+
+            trie_set = [Trie() for _ in table['header_name']]
+            for row in table['rows']:
+                for ii, cell in enumerate(row):
+                    if 'real' in table['header_types'][ii].lower() or \
+                            'number' in table['header_types'][ii].lower() or \
+                            'duration' in table['header_types'][ii].lower():
+                        continue
+                    word = str(cell).strip().lower()
+                    trie_set[ii].insert(word, word)
+
+            table['value_trie'] = trie_set
+            tables[table['table_id']] = table
+
+        return tables
+
+    def init_syn_dict(self, syn_dict_file_path):
+        lines = []
+        with open(syn_dict_file_path, encoding='utf-8') as fo:
+            for line in fo:
+                lines.append(line)
+
+        syn_dict = {}
+        for line in tqdm.tqdm(lines, desc='Load Synonym Dict'):
+            tokens = line.strip().split('\t')
+            if len(tokens) != 2:
+                continue
+            keys = tokens[0].strip().split('|')
+            values = tokens[1].strip().split('|')
+            for key in keys:
+                key = key.lower().strip()
+                syn_dict.setdefault(key, [])
+                for value in values:
+                    syn_dict[key].append(value.lower().strip())
+
+        return syn_dict
diff --git a/modelscope/preprocessors/star3/fields/schema_link.py b/modelscope/preprocessors/star3/fields/schema_link.py
new file mode 100644
index 00000000..40613f78
--- /dev/null
+++ b/modelscope/preprocessors/star3/fields/schema_link.py
@@ -0,0 +1,423 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
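+# SchemaLinker aligns question spans with columns and cell values; the links
+# become the bertindex_knowledge / header_knowledge features and the
+# schema_link relations fed into the model inputs.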
+import re + +from modelscope.preprocessors.star3.fields.struct import TypeInfo + + +class SchemaLinker: + + def __init__(self): + pass + + def find_in_list(self, comlist, words): + result = False + for com in comlist: + if words in com: + result = True + break + return result + + def get_continue_score(self, pstr, tstr): + comlist = [] + minlen = min(len(pstr), len(tstr)) + for slen in range(minlen, 1, -1): + for ts in range(0, len(tstr), 1): + if ts + slen > len(tstr): + continue + words = tstr[ts:ts + slen] + if words in pstr and not self.find_in_list(comlist, words): + comlist.append(words) + + comlen = 0 + for com in comlist: + comlen += len(com) * len(com) + weight = comlen / (len(tstr) * len(tstr) + 0.001) + if weight > 1.0: + weight = 1.0 + + return weight + + def get_match_score(self, ptokens, ttokens): + pset = set(ptokens) + tset = set(ttokens) + comset = pset & tset + allset = pset | tset + weight2 = len(comset) / (len(allset) + 0.001) + weight3 = self.get_continue_score(''.join(ptokens), ''.join(ttokens)) + return 0.4 * weight2 + 0.6 * weight3 + + def is_number(self, s): + try: + float(s) + return True + except ValueError: + pass + + try: + import unicodedata + unicodedata.numeric(s) + return True + except (TypeError, ValueError): + pass + + return False + + def get_match_phrase(self, query, target): + if target in query: + return target, 1.0 + + qtokens = [] + for i in range(0, len(query), 1): + qtokens.append(query[i:i + 1]) + ttokens = [] + for i in range(0, len(target), 1): + ttokens.append(target[i:i + 1]) + ttok_set = set(ttokens) + + phrase = '' + score = 0.0 + for qidx, qword in enumerate(qtokens): + if qword not in ttok_set: + continue + + eidx = (qidx + 2 * len(ttokens)) if ( + len(qtokens) > qidx + 2 * len(ttokens)) else len(qtokens) + while eidx > qidx: + ptokens = qtokens[qidx:eidx] + weight = self.get_match_score(ptokens, ttokens) + if weight + 0.001 > score: + score = weight + phrase = ''.join(ptokens) + eidx -= 1 + + if self.is_number(target) and phrase != target: + score = 0.0 + if len(phrase) > 1 and phrase in target: + score *= (1.0 + 0.05 * len(phrase)) + + return phrase, score + + def allfindpairidx(self, que_tok, value_tok, weight): + idxs = [] + for i in range(0, len(que_tok) - len(value_tok) + 1, 1): + s = i + e = i + matched = True + for j in range(0, len(value_tok), 1): + if value_tok[j].lower() == que_tok[i + j].lower(): + e = i + j + else: + matched = False + break + if matched: + idxs.append([s, e, weight]) + + return idxs + + def findnear(self, ps1, pe1, ps2, pe2): + if abs(ps1 - pe2) <= 2 or abs(pe1 - ps2) <= 2: + return True + return False + + def get_column_type(self, col_idx, table): + colType = table['header_types'][col_idx] + if 'number' in colType or 'duration' in colType or 'real' in colType: + colType = 'real' + elif 'date' in colType: + colType = 'date' + elif 'bool' in colType: + colType = 'bool' + else: + colType = 'text' + + return colType + + def add_type_all(self, typeinfos, index, idxs, label, linktype, value, + orgvalue): + for idx in idxs: + info = TypeInfo(label, index, linktype, value, orgvalue, idx[0], + idx[1], idx[2]) + flag = True + for i, typeinfo in enumerate(typeinfos): + if info.pstart < typeinfo.pstart: + typeinfos.insert(i, info) + flag = False + break + + if flag: + typeinfos.append(info) + + return typeinfos + + def save_info(self, tinfo, sinfo): + flag = True + if tinfo.pstart > sinfo.pend or tinfo.pend < sinfo.pstart: + pass + elif tinfo.pstart >= sinfo.pstart and \ + tinfo.pend <= sinfo.pend and tinfo.index == 
-1: + flag = False + elif tinfo.pstart == sinfo.pstart and sinfo.pend == tinfo.pend and \ + abs(tinfo.weight - sinfo.weight) < 0.01: + pass + else: + if sinfo.label == 'col' or sinfo.label == 'val': + if tinfo.label == 'col' or tinfo.label == 'val': + if (sinfo.pend + - sinfo.pstart) > (tinfo.pend - tinfo.pstart) or ( + sinfo.weight > tinfo.weight + and sinfo.index != -1): + flag = False + else: + flag = False + else: + if (tinfo.label == 'op' or tinfo.label == 'agg'): + if (sinfo.pend - sinfo.pstart) > ( + tinfo.pend + - tinfo.pstart) or sinfo.weight > tinfo.weight: + flag = False + + return flag + + def normal_type_infos(self, infos): + typeinfos = [] + for info in infos: + typeinfos = [x for x in typeinfos if self.save_info(x, info)] + flag = True + for i, typeinfo in enumerate(typeinfos): + if not self.save_info(info, typeinfo): + flag = False + break + if info.pstart < typeinfo.pstart: + typeinfos.insert(i, info) + flag = False + break + if flag: + typeinfos.append(info) + return typeinfos + + def findnear_typeinfo(self, info1, info2): + return self.findnear(info1.pstart, info1.pend, info2.pstart, + info2.pend) + + def find_real_column(self, infos, table): + for i, vinfo in enumerate(infos): + if vinfo.index != -1 or vinfo.label != 'val': + continue + eoidx = -1 + for j, oinfo in enumerate(infos): + if oinfo.label != 'op': + continue + if self.findnear_typeinfo(vinfo, oinfo): + eoidx = j + break + for j, cinfo in enumerate(infos): + if cinfo.label != 'col' or table['header_types'][ + cinfo.index] != 'real': + continue + if self.findnear_typeinfo(cinfo, vinfo) or ( + eoidx != -1 + and self.findnear_typeinfo(cinfo, infos[eoidx])): + infos[i].index = cinfo.index + break + + return infos + + def filter_column_infos(self, infos): + delid = [] + for i, info in enumerate(infos): + if info.label != 'col': + continue + for j in range(i + 1, len(infos), 1): + if infos[j].label == 'col' and \ + info.pstart == infos[j].pstart and \ + info.pend == infos[j].pend: + delid.append(i) + delid.append(j) + break + + typeinfos = [] + for idx, info in enumerate(infos): + if idx in set(delid): + continue + typeinfos.append(info) + + return typeinfos + + def filter_type_infos(self, infos, table): + infos = self.filter_column_infos(infos) + infos = self.find_real_column(infos, table) + + colvalMp = {} + for info in infos: + if info.label == 'col': + colvalMp[info.index] = [] + for info in infos: + if info.label == 'val' and info.index in colvalMp: + colvalMp[info.index].append(info) + + delid = [] + for idx, info in enumerate(infos): + if info.label != 'val' or info.index in colvalMp: + continue + for index in colvalMp.keys(): + valinfos = colvalMp[index] + for valinfo in valinfos: + if valinfo.pstart <= info.pstart and \ + valinfo.pend >= info.pend: + delid.append(idx) + break + + typeinfos = [] + for idx, info in enumerate(infos): + if idx in set(delid): + continue + typeinfos.append(info) + + return typeinfos + + def get_table_match_score(self, nlu_t, schema_link): + match_len = 0 + for info in schema_link: + scale = 0.6 + if info['question_len'] > 0 and info['column_index'] != -1: + scale = 1.0 + else: + scale = 0.5 + match_len += scale * info['question_len'] * info['weight'] + + return match_len / (len(nlu_t) + 0.1) + + def get_entity_linking(self, tokenizer, nlu, nlu_t, tables, col_syn_dict): + """ + get linking between question and schema column + """ + typeinfos = [] + numbers = re.findall(r'[-]?\d*\.\d+|[-]?\d+|\d+', nlu) + + # search schema link in every table + search_result_list = [] + for 
tablename in tables:
+            table = tables[tablename]
+            trie_set = []
+            if 'value_trie' in table:
+                trie_set = table['value_trie']
+
+            typeinfos = []
+            for ii, column in enumerate(table['header_name']):
+                column = column.lower()
+                column_new = re.sub(r'\(.*?\)', '', column)
+                column_new = re.sub(r'（.*?）', '', column_new)
+                cphrase, cscore = self.get_match_phrase(
+                    nlu.lower(), column_new)
+                if cscore > 0.3 and cphrase.strip() != '':
+                    phrase_tok = tokenizer.tokenize(cphrase)
+                    cidxs = self.allfindpairidx(nlu_t, phrase_tok, cscore)
+                    typeinfos = self.add_type_all(typeinfos, ii, cidxs, 'col',
+                                                  'column', cphrase, column)
+                if cscore < 0.8 and column_new in col_syn_dict:
+                    columns = list(set(col_syn_dict[column_new]))
+                    for syn_col in columns:
+                        if syn_col not in nlu.lower() or syn_col == '':
+                            continue
+                        phrase_tok = tokenizer.tokenize(syn_col)
+                        cidxs = self.allfindpairidx(nlu_t, phrase_tok, 1.0)
+                        typeinfos = self.add_type_all(typeinfos, ii, cidxs,
+                                                      'col', 'column', syn_col,
+                                                      column)
+
+            for ii, trie in enumerate(trie_set):
+                ans = trie.match(nlu.lower())
+                for cell in ans.keys():
+                    vphrase = cell
+                    vscore = 1.0
+                    phrase_tok = tokenizer.tokenize(vphrase)
+                    if len(phrase_tok) == 0 or len(vphrase) < 2:
+                        continue
+                    vidxs = self.allfindpairidx(nlu_t, phrase_tok, vscore)
+                    linktype = self.get_column_type(ii, table)
+                    typeinfos = self.add_type_all(typeinfos, ii, vidxs, 'val',
+                                                  linktype, vphrase, ans[cell])
+
+            for number in set(numbers):
+                number_tok = tokenizer.tokenize(number.lower())
+                if len(number_tok) == 0:
+                    continue
+                nidxs = self.allfindpairidx(nlu_t, number_tok, 1.0)
+                typeinfos = self.add_type_all(typeinfos, -1, nidxs, 'val',
+                                              'real', number, number)
+
+            newtypeinfos = self.normal_type_infos(typeinfos)
+
+            newtypeinfos = self.filter_type_infos(newtypeinfos, table)
+
+            final_question = [0] * len(nlu_t)
+            final_header = [0] * len(table['header_name'])
+            for typeinfo in newtypeinfos:
+                pstart = typeinfo.pstart
+                pend = typeinfo.pend + 1
+                if typeinfo.label == 'op' or typeinfo.label == 'agg':
+                    score = int(typeinfo.linktype[-1])
+                    if typeinfo.label == 'op':
+                        score += 6
+                    else:
+                        score += 11
+                    for i in range(pstart, pend, 1):
+                        final_question[i] = score
+
+                elif typeinfo.label == 'col':
+                    for i in range(pstart, pend, 1):
+                        final_question[i] = 4
+                    if final_header[typeinfo.index] % 2 == 0:
+                        final_header[typeinfo.index] += 1
+
+                elif typeinfo.label == 'val':
+                    if typeinfo.index == -1:
+                        for i in range(pstart, pend, 1):
+                            final_question[i] = 5
+                    else:
+                        for i in range(pstart, pend, 1):
+                            final_question[i] = 2
+                        final_question[pstart] = 1
+                        final_question[pend - 1] = 3
+                        if final_header[typeinfo.index] < 2:
+                            final_header[typeinfo.index] += 2
+
+            # collect schema_link
+            schema_link = []
+            for sl in newtypeinfos:
+                if sl.label in ['val', 'col']:
+                    schema_link.append({
+                        'question_len':
+                        max(0, sl.pend - sl.pstart + 1),
+                        'question_index': [sl.pstart, sl.pend],
+                        'question_span':
+                        ''.join(nlu_t[sl.pstart:sl.pend + 1]),
+                        'column_index':
+                        sl.index,
+                        'column_span':
+                        table['header_name'][sl.index]
+                        if sl.index != -1 else '空列',
+                        'label':
+                        sl.label,
+                        'weight':
+                        round(sl.weight, 4)
+                    })
+
+            # get the match score of each table
+            match_score = self.get_table_match_score(nlu_t, schema_link)
+
+            search_result = {
+                'table_id': table['table_id'],
+                'question_knowledge': final_question,
+                'header_knowledge': final_header,
+                'schema_link': schema_link,
+                'match_score': match_score
+            }
+            search_result_list.append(search_result)
+
+        search_result_list = sorted(
+            search_result_list, key=lambda x: x['match_score'],
+            reverse=True)[0:4]
+
+        return search_result_list
diff --git a/modelscope/preprocessors/star3/fields/struct.py b/modelscope/preprocessors/star3/fields/struct.py
new file mode 100644
index 00000000..3c2e664b
--- /dev/null
+++ b/modelscope/preprocessors/star3/fields/struct.py
@@ -0,0 +1,181 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+cond_ops = ['>', '<', '==', '!=', 'ASC', 'DESC']
+agg_ops = [
+    '', 'AVG', 'MAX', 'MIN', 'COUNT', 'SUM', 'COMPARE', 'GROUP BY', 'SAME'
+]
+conn_ops = ['', 'AND', 'OR']
+
+
+class Context:
+
+    def __init__(self):
+        self.history_sql = None
+
+    def set_history_sql(self, sql):
+        self.history_sql = sql
+
+
+class SQLQuery:
+
+    def __init__(self, string, query, sql_result):
+        self.string = string
+        self.query = query
+        self.sql_result = sql_result
+
+
+class TrieNode(object):
+
+    def __init__(self):
+        """
+        Initialize the trie node with an empty child map.
+        """
+        self.data = {}
+        self.is_word = False
+        self.term = None
+
+
+class Trie(object):
+
+    def __init__(self):
+        self.root = TrieNode()
+
+    def insert(self, word, term):
+        """
+        Inserts a word into the trie.
+        :type word: str
+        :rtype: None
+        """
+        node = self.root
+        for letter in word:
+            child = node.data.get(letter)
+            if not child:
+                node.data[letter] = TrieNode()
+            node = node.data[letter]
+        node.is_word = True
+        node.term = term
+
+    def search(self, word):
+        """
+        Returns (term, True) if the path exists in the trie; term is None
+        unless an inserted word ends exactly at this node.
+        :type word: str
+        :rtype: (object, bool)
+        """
+        node = self.root
+        for letter in word:
+            node = node.data.get(letter)
+            if not node:
+                return None, False
+        return node.term, True
+
+    def match(self, query):
+        start = 0
+        end = 1
+        length = len(query)
+        ans = {}
+        # end may reach len(query) so that substrings ending at the last
+        # character are also tested
+        while start < length and end <= length:
+            sub = query[start:end]
+            term, flag = self.search(sub)
+            if flag:
+                if term is not None:
+                    ans[sub] = term
+                end += 1
+            else:
+                start += 1
+                end = start + 1
+        return ans
+
+    def starts_with(self, prefix):
+        """
+        Returns if there is any word in the trie
+        that starts with the given prefix.
+        :type prefix: str
+        :rtype: bool
+        """
+        node = self.root
+        for letter in prefix:
+            node = node.data.get(letter)
+            if not node:
+                return False
+        return True
+
+    def get_start(self, prefix):
+        """
+        Returns the words starting with prefix
+        :param prefix:
+        :return: words (list)
+        """
+
+        def _get_key(pre, pre_node):
+            words_list = []
+            if pre_node.is_word:
+                words_list.append(pre)
+            for x in pre_node.data.keys():
+                words_list.extend(_get_key(pre + str(x), pre_node.data.get(x)))
+            return words_list
+
+        words = []
+        if not self.starts_with(prefix):
+            return words
+        term, _ = self.search(prefix)
+        if term is not None:
+            words.append(prefix)
+            return words
+        node = self.root
+        for letter in prefix:
+            node = node.data.get(letter)
+        return _get_key(prefix, node)
+
+
+class TypeInfo:
+
+    def __init__(self, label, index, linktype, value, orgvalue, pstart, pend,
+                 weight):
+        self.label = label
+        self.index = index
+        self.linktype = linktype
+        self.value = value
+        self.orgvalue = orgvalue
+        self.pstart = pstart
+        self.pend = pend
+        self.weight = weight
+
+
+class Constant:
+
+    def __init__(self):
+        self.action_ops = [
+            'add_cond', 'change_cond', 'del_cond', 'change_focus_total',
+            'change_agg_only', 'del_focus', 'restart', 'switch_table',
+            'out_of_scripts', 'repeat', 'firstTurn'
+        ]
+
+        self.agg_ops = [
+            '', 'AVG', 'MAX', 'MIN', 'COUNT', 'SUM', 'COMPARE', 'GROUP BY',
+            'SAME'
+        ]
+
+        self.cond_ops = ['>', '<', '==', '!=', 'ASC', 'DESC']
+
+        self.cond_conn_ops = ['', 'AND', 'OR']
+
+        self.col_type_dict = {
+            'null': 0,
+            'text': 1,
+            'number': 2,
+            'duration': 3,
+            'bool': 4,
+            'date': 5
+        }
+
+        self.schema_link_dict = {
+            'col_start': 1,
+            'col_middle': 2,
+            'col_end': 3,
+            'val_start': 4,
+            'val_middle': 5,
+            'val_end': 6
+        }
+
+        self.max_select_num = 4
+
+        self.max_where_num = 6
diff --git a/modelscope/preprocessors/star3/table_question_answering_preprocessor.py b/modelscope/preprocessors/star3/table_question_answering_preprocessor.py
new file mode 100644
index 00000000..163759a1
--- /dev/null
+++ b/modelscope/preprocessors/star3/table_question_answering_preprocessor.py
@@ -0,0 +1,118 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
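+# The preprocessor tokenizes the question, runs schema linking over every
+# table in the Database, and keeps only the best-matching table's features
+# as model input.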
+import os
+from typing import Any, Dict
+
+import torch
+from transformers import BertTokenizer
+
+from modelscope.metainfo import Preprocessors
+from modelscope.preprocessors.base import Preprocessor
+from modelscope.preprocessors.builder import PREPROCESSORS
+from modelscope.preprocessors.star3.fields.database import Database
+from modelscope.preprocessors.star3.fields.schema_link import SchemaLinker
+from modelscope.utils.config import Config
+from modelscope.utils.constant import Fields, ModelFile
+from modelscope.utils.type_assert import type_assert
+
+__all__ = ['TableQuestionAnsweringPreprocessor']
+
+
+@PREPROCESSORS.register_module(
+    Fields.nlp,
+    module_name=Preprocessors.table_question_answering_preprocessor)
+class TableQuestionAnsweringPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, db: Database = None, *args, **kwargs):
+        """preprocess the data
+
+        Args:
+            model_dir (str): model path
+            db (Database): database instance
+        """
+        super().__init__(*args, **kwargs)
+
+        self.model_dir: str = model_dir
+        self.config = Config.from_file(
+            os.path.join(self.model_dir, ModelFile.CONFIGURATION))
+
+        # read tokenizer
+        self.tokenizer = BertTokenizer(
+            os.path.join(self.model_dir, ModelFile.VOCAB_FILE))
+
+        # read database
+        if db is None:
+            self.db = Database(
+                tokenizer=self.tokenizer,
+                table_file_path=os.path.join(self.model_dir, 'table.json'),
+                syn_dict_file_path=os.path.join(self.model_dir, 'synonym.txt'))
+        else:
+            self.db = db
+
+        # get schema linker
+        self.schema_linker = SchemaLinker()
+
+        # set device
+        self.device = 'cuda' if \
+            ('device' not in kwargs or kwargs['device'] == 'gpu') \
+            and torch.cuda.is_available() else 'cpu'
+
+    def construct_data(self, search_result_list, nlu, nlu_t, db, history_sql):
+        datas = []
+        for search_result in search_result_list:
+            data = {}
+            data['table_id'] = search_result['table_id']
+            data['question'] = nlu
+            data['question_tok'] = nlu_t
+            data['header_tok'] = db.tables[data['table_id']]['header_tok']
+            data['types'] = db.tables[data['table_id']]['header_types']
+            data['units'] = db.tables[data['table_id']]['header_units']
+            data['action'] = 0
+            data['sql'] = None
+            data['history_sql'] = history_sql
+            data['wvi_corenlp'] = []
+            data['bertindex_knowledge'] = search_result['question_knowledge']
+            data['header_knowledge'] = search_result['header_knowledge']
+            data['schema_link'] = search_result['schema_link']
+            datas.append(data)
+
+        return datas
+
+    @type_assert(object, dict)
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        """process the raw input data
+
+        Args:
+            data (dict):
+                question: the current utterance
+                history_sql: the predicted sql dict of the last utterance,
+                    or None on the first turn
+            Example:
+                question: 'Which of these are hiring?'
+                history_sql: None
+
+        Returns:
+            Dict[str, Any]: the preprocessed data
+        """
+
+        # tokenize question
+        question = data['question']
+        history_sql = data['history_sql']
+        nlu = question.lower()
+        nlu_t = self.tokenizer.tokenize(nlu)
+
+        # get linking
+        search_result_list = self.schema_linker.get_entity_linking(
+            tokenizer=self.tokenizer,
+            nlu=nlu,
+            nlu_t=nlu_t,
+            tables=self.db.tables,
+            col_syn_dict=self.db.syn_dict)
+
+        # collect data of the best-matching table
+        datas = self.construct_data(
+            search_result_list=search_result_list[0:1],
+            nlu=nlu,
+            nlu_t=nlu_t,
+            db=self.db,
+            history_sql=history_sql)
+
+        return {'datas': datas}
diff --git a/modelscope/utils/nlp/nlp_utils.py b/modelscope/utils/nlp/nlp_utils.py
index af539dda..0b0ea61d 100644
--- a/modelscope/utils/nlp/nlp_utils.py
+++ b/modelscope/utils/nlp/nlp_utils.py
@@ -3,7 +3,8 @@ from typing import List
 
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.nlp import (ConversationalTextToSqlPipeline,
-                                      DialogStateTrackingPipeline)
+                                      DialogStateTrackingPipeline,
+                                      TableQuestionAnsweringPipeline)
 
 
 def text2sql_tracking_and_print_results(
@@ -42,3 +43,17 @@ def tracking_and_print_dialog_states(
         print(json.dumps(result))
 
         history_states.extend([result[OutputKeys.OUTPUT], {}])
+
+
+def tableqa_tracking_and_print_results(
+        test_case, pipelines: List[TableQuestionAnsweringPipeline]):
+    for pipeline in pipelines:
+        historical_queries = None
+        for question in test_case['utterance']:
+            output_dict = pipeline({
+                'question': question,
+                'history_sql': historical_queries
+            })
+            print('output_dict', output_dict['output'].string,
+                  output_dict['output'].query)
+            historical_queries = output_dict['history']
diff --git a/tests/pipelines/test_table_question_answering.py b/tests/pipelines/test_table_question_answering.py
new file mode 100644
index 00000000..3c416cd5
--- /dev/null
+++ b/tests/pipelines/test_table_question_answering.py
@@ -0,0 +1,76 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
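+# Covers four construction paths: direct snapshot download, a Model instance
+# from the hub, the task-default pipeline, and externally supplied
+# tokenizer/Database objects.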
+import os
+import unittest
+from typing import List
+
+from transformers import BertTokenizer
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models import Model
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.nlp import TableQuestionAnsweringPipeline
+from modelscope.preprocessors import TableQuestionAnsweringPreprocessor
+from modelscope.preprocessors.star3.fields.database import Database
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.nlp.nlp_utils import tableqa_tracking_and_print_results
+from modelscope.utils.test_utils import test_level
+
+
+class TableQuestionAnswering(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.task = Tasks.table_question_answering
+        self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn'
+
+    test_case = {
+        'utterance':
+        ['长江流域的小(2)型水库的库容总量是多少?', '那平均值是多少?', '那水库的名称呢?', '换成中型的呢?']
+    }
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_by_direct_model_download(self):
+        cache_path = snapshot_download(self.model_id)
+        preprocessor = TableQuestionAnsweringPreprocessor(model_dir=cache_path)
+        pipelines = [
+            TableQuestionAnsweringPipeline(
+                model=cache_path, preprocessor=preprocessor)
+        ]
+        tableqa_tracking_and_print_results(self.test_case, pipelines)
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.model_id)
+        preprocessor = TableQuestionAnsweringPreprocessor(
+            model_dir=model.model_dir)
+        pipelines = [
+            TableQuestionAnsweringPipeline(
+                model=model, preprocessor=preprocessor)
+        ]
+        tableqa_tracking_and_print_results(self.test_case, pipelines)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_from_task(self):
+        pipelines = [pipeline(Tasks.table_question_answering, self.model_id)]
+        tableqa_tracking_and_print_results(self.test_case, pipelines)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_model_from_modelhub_with_other_classes(self):
+        model = Model.from_pretrained(self.model_id)
+        self.tokenizer = BertTokenizer(
+            os.path.join(model.model_dir, ModelFile.VOCAB_FILE))
+        db = Database(
+            tokenizer=self.tokenizer,
+            table_file_path=os.path.join(model.model_dir, 'table.json'),
+            syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'))
+        preprocessor = TableQuestionAnsweringPreprocessor(
+            model_dir=model.model_dir, db=db)
+        pipelines = [
+            TableQuestionAnsweringPipeline(
+                model=model, preprocessor=preprocessor, db=db)
+        ]
+        tableqa_tracking_and_print_results(self.test_case, pipelines)
+
+
+if __name__ == '__main__':
+    unittest.main()
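
Reviewer note: a minimal multi-turn usage sketch. It assumes the model id used
in the tests above; the 'question'/'history_sql' keys are the ones consumed by
TableQuestionAnsweringPreprocessor, and the loop mirrors
tableqa_tracking_and_print_results:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # build the pipeline from the default model registered in builder.py
    table_qa = pipeline(Tasks.table_question_answering,
                        'damo/nlp_convai_text2sql_pretrain_cn')

    # multi-turn loop: feed each turn's 'history' output back as history_sql
    history_sql = None
    for question in ['长江流域的小(2)型水库的库容总量是多少?', '那平均值是多少?']:
        output = table_qa({'question': question, 'history_sql': history_sql})
        sql = output['output']           # SQLQuery struct
        print(sql.string, sql.query)     # readable and executable forms
        history_sql = output['history']  # sql dict merged with the history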