diff --git a/modelscope/models/nlp/gpt3/modeling_gpt3.py b/modelscope/models/nlp/gpt3/modeling_gpt3.py index 498d15de..ade36e36 100644 --- a/modelscope/models/nlp/gpt3/modeling_gpt3.py +++ b/modelscope/models/nlp/gpt3/modeling_gpt3.py @@ -346,3 +346,6 @@ class GPT3Model(PreTrainedModel): } model.load_state_dict(state_dict) return model + + def prepare_inputs_for_generation(self, input_ids, *args, **kwargs): + return {'input_ids': input_ids} diff --git a/requirements/nlp.txt b/requirements/nlp.txt index 2e0838fc..123c238e 100644 --- a/requirements/nlp.txt +++ b/requirements/nlp.txt @@ -14,5 +14,4 @@ spacy>=2.3.5 subword_nmt>=0.3.8 text2sql_lgesql tokenizers -# recent 4.23.1 update introduce breaking api change, limit upper version temporarily. -transformers>=4.12.0,<=4.22.0 +transformers