diff --git a/maas_lib/trainers/nlp/sequence_classification_trainer.py b/maas_lib/trainers/nlp/sequence_classification_trainer.py index e88eb95e..f2264c0d 100644 --- a/maas_lib/trainers/nlp/sequence_classification_trainer.py +++ b/maas_lib/trainers/nlp/sequence_classification_trainer.py @@ -128,7 +128,7 @@ class SequenceClassificationTrainer(BaseTrainer): collate_fn=pre_dataset.batch_fn) # generate a model - model = SequenceClassification(checkpoint_path) + model = SequenceClassification.from_pretrained(checkpoint_path) # copy from easynlp (start) model.eval()