|
|
|
@@ -31,24 +31,23 @@ class SbertModel(TorchModel, SbertModelTransform): |
|
|
|
def extract_pooled_outputs(self, outputs): |
|
|
|
return outputs['pooler_output'] |
|
|
|
|
|
|
|
def forward( |
|
|
|
self, |
|
|
|
input_ids=None, |
|
|
|
attention_mask=None, |
|
|
|
token_type_ids=None, |
|
|
|
position_ids=None, |
|
|
|
head_mask=None, |
|
|
|
inputs_embeds=None, |
|
|
|
encoder_hidden_states=None, |
|
|
|
encoder_attention_mask=None, |
|
|
|
past_key_values=None, |
|
|
|
use_cache=None, |
|
|
|
output_attentions=None, |
|
|
|
output_hidden_states=None, |
|
|
|
return_dict=None, |
|
|
|
): |
|
|
|
def forward(self, |
|
|
|
input_ids=None, |
|
|
|
attention_mask=None, |
|
|
|
token_type_ids=None, |
|
|
|
position_ids=None, |
|
|
|
head_mask=None, |
|
|
|
inputs_embeds=None, |
|
|
|
encoder_hidden_states=None, |
|
|
|
encoder_attention_mask=None, |
|
|
|
past_key_values=None, |
|
|
|
use_cache=None, |
|
|
|
output_attentions=None, |
|
|
|
output_hidden_states=None, |
|
|
|
return_dict=None, |
|
|
|
**kwargs): |
|
|
|
return SbertModelTransform.forward( |
|
|
|
self, input_ids, attention_mask, token_type_ids, position_ids, |
|
|
|
head_mask, inputs_embeds, encoder_hidden_states, |
|
|
|
encoder_attention_mask, past_key_values, use_cache, |
|
|
|
output_attentions, output_hidden_states, return_dict) |
|
|
|
output_attentions, output_hidden_states, return_dict, **kwargs) |