|
|
@@ -346,3 +346,6 @@ class GPT3Model(PreTrainedModel): |
|
|
} |
|
|
} |
|
|
model.load_state_dict(state_dict) |
|
|
model.load_state_dict(state_dict) |
|
|
return model |
|
|
return model |
|
|
|
|
|
|
|
|
|
|
|
def prepare_inputs_for_generation(self, input_ids, *args, **kwargs): |
|
|
|
|
|
return {'input_ids': input_ids} |