@@ -17,21 +17,15 @@
from __future__ import absolute_import, division, print_function
import copy
import logging
import math
import os
import shutil
import tarfile
import tempfile
from pathlib import Path
from typing import Union

import json
import numpy as np
import torch
import torch_scatter
from torch import nn
from torch.nn import CrossEntropyLoss

from modelscope.models.nlp.star3.configuration_star3 import Star3Config
from modelscope.utils.constant import ModelFile
@@ -121,33 +115,17 @@ class BertEmbeddings(nn.Module):
        words_embeddings = self.word_embeddings(input_ids)
        header_embeddings = self.word_embeddings(header_ids)

        # header mean pooling
        header_flatten_embeddings = self.word_embeddings(header_flatten_tokens)
        header_flatten_index = header_flatten_index.reshape(
            (-1, header_flatten_index.shape[1], 1))
        header_flatten_index = header_flatten_index.repeat(
            1, 1, header_flatten_embeddings.shape[2])
        header_flatten_output = header_flatten_output.reshape(
            (-1, header_flatten_output.shape[1], 1))
        header_flatten_output = header_flatten_output.repeat(
            1, 1, header_flatten_embeddings.shape[2])
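        # The per-token column index and the output buffer are expanded above to
        # the embedding dimension of the flattened header-token embeddings, so the
        # scatter below can reduce them element-wise along dim=1.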
        header_embeddings = torch_scatter.scatter_mean(
            header_flatten_embeddings,
            header_flatten_index,
            out=header_flatten_output,
            dim=1)
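        # After the scatter, header_embeddings[b, j, :] is the mean of all flattened
        # header-token embeddings whose column index is j (e.g. an index row
        # [0, 0, 1] averages tokens 0-1 into slot 0 and puts token 2 into slot 1).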
        token_column_id = token_column_id.reshape(
            (-1, token_column_id.shape[1], 1))
        token_column_id = token_column_id.repeat(
            (1, 1, header_embeddings.shape[2]))
        token_column_mask = token_column_mask.reshape(
            (-1, token_column_mask.shape[1], 1))
        token_column_mask = token_column_mask.repeat(
            (1, 1, header_embeddings.shape[2]))
        token_header_embeddings = torch.gather(header_embeddings, 1,
                                               token_column_id)
        words_embeddings = words_embeddings * (1.0 - token_column_mask) + \
            token_header_embeddings * token_column_mask
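        # Tokens linked to a column (token_column_mask == 1) take the mean-pooled
        # embedding of that column gathered via token_column_id; all other tokens
        # keep their plain word embeddings.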
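        # col_dict maps a token position ki to a column index vi; each linked
        # position is overwritten with the mean of the first `length` header-token
        # embeddings of that column.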
        if col_dict_list is not None and l_hs is not None:
            col_dict_list = np.array(col_dict_list)[ids.cpu().numpy()].tolist()
            header_len = np.array(
                header_len, dtype=object)[ids.cpu().numpy()].tolist()
            for bi, col_dict in enumerate(col_dict_list):
                for ki, vi in col_dict.items():
                    length = header_len[bi][vi]
                    if length == 0:
                        continue
                    words_embeddings[bi, ki, :] = torch.mean(
                        header_embeddings[bi, vi, :length, :], dim=0)

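        # Standard BERT position and token-type (segment) embeddings.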
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)