|
|
|
@@ -51,12 +51,9 @@ class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel): |
|
|
|
return addict.Dict(outputs) |
|
|
|
|
|
|
|
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): |
|
|
|
token_type_ids = kwargs.get('token_type_ids', None) |
|
|
|
# only last token for inputs_ids if past is defined in kwargs |
|
|
|
if past: |
|
|
|
input_ids = input_ids[:, -1].unsqueeze(-1) |
|
|
|
if token_type_ids is not None: |
|
|
|
token_type_ids = token_type_ids[:, -1].unsqueeze(-1) |
|
|
|
|
|
|
|
attention_mask = kwargs.get('attention_mask', None) |
|
|
|
position_ids = kwargs.get('position_ids', None) |
|
|
|
@@ -75,5 +72,8 @@ class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel): |
|
|
|
'use_cache': kwargs.get('use_cache'), |
|
|
|
'position_ids': position_ids, |
|
|
|
'attention_mask': attention_mask, |
|
|
|
'token_type_ids': token_type_ids, |
|
|
|
} |
|
|
|
|
|
|
|
def generate(self, inputs, *args, **kwargs):
    """Delegate text generation to the parent class.

    Accepts either a dict-style container holding an ``input_ids`` entry or
    the input-ids tensor itself, normalizes to the tensor form, and forwards
    everything else untouched to ``super().generate``.
    """
    if isinstance(inputs, Dict):
        input_ids = inputs['input_ids']
    else:
        input_ids = inputs
    return super().generate(input_ids, *args, **kwargs)