zhangzhicheng.zzc wenmeng.zwm 3 years ago
parent commit ceee95f763
2 changed files with 31 additions and 33 deletions
  1. modelscope/models/nlp/backbones/structbert.py (+16 -17)
  2. modelscope/models/nlp/structbert/modeling_sbert.py (+15 -16)

modelscope/models/nlp/backbones/structbert.py (+16 -17)

@@ -31,24 +31,23 @@ class SbertModel(TorchModel, SbertModelTransform):
     def extract_pooled_outputs(self, outputs):
         return outputs['pooler_output']
 
-    def forward(
-        self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_values=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+    def forward(self,
+                input_ids=None,
+                attention_mask=None,
+                token_type_ids=None,
+                position_ids=None,
+                head_mask=None,
+                inputs_embeds=None,
+                encoder_hidden_states=None,
+                encoder_attention_mask=None,
+                past_key_values=None,
+                use_cache=None,
+                output_attentions=None,
+                output_hidden_states=None,
+                return_dict=None,
+                **kwargs):
         return SbertModelTransform.forward(
             self, input_ids, attention_mask, token_type_ids, position_ids,
             head_mask, inputs_embeds, encoder_hidden_states,
             encoder_attention_mask, past_key_values, use_cache,
-            output_attentions, output_hidden_states, return_dict)
+            output_attentions, output_hidden_states, return_dict, **kwargs)
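The net effect of this hunk: the wrapper's forward now accepts arbitrary extra keyword arguments and threads them through to SbertModelTransform.forward instead of raising a TypeError. A minimal sketch of the pass-through pattern, with illustrative class names that are not from this commit:

class Transform:
    def forward(self, input_ids=None, **kwargs):
        # The underlying implementation sees any extra keywords.
        return {'input_ids': input_ids, **kwargs}

class Model(Transform):
    def forward(self, input_ids=None, **kwargs):
        # Pass-through: unknown keywords are forwarded, not rejected.
        return Transform.forward(self, input_ids, **kwargs)

print(Model().forward(input_ids=[101, 102], output_scores=True))
# -> {'input_ids': [101, 102], 'output_scores': True}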

modelscope/models/nlp/structbert/modeling_sbert.py (+15 -16)

@@ -870,22 +870,21 @@ class SbertModel(SbertPreTrainedModel):
         output_type=BaseModelOutputWithPoolingAndCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
-    def forward(
-        self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_values=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+    def forward(self,
+                input_ids=None,
+                attention_mask=None,
+                token_type_ids=None,
+                position_ids=None,
+                head_mask=None,
+                inputs_embeds=None,
+                encoder_hidden_states=None,
+                encoder_attention_mask=None,
+                past_key_values=None,
+                use_cache=None,
+                output_attentions=None,
+                output_hidden_states=None,
+                return_dict=None,
+                **kwargs):
         r"""
         encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
             `optional`):
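Here in modeling_sbert.py the signature likewise gains **kwargs; the hunk does not show the body using them, so they are presumably absorbed so that callers (such as the backbone wrapper above) can pass a superset of arguments. A hypothetical call, where the extra keyword is illustrative only:

outputs = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    extra_option=True)  # hypothetical kwarg; a TypeError before this change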

