diff --git a/tests/trainers/test_text_generation_trainer.py b/tests/trainers/test_text_generation_trainer.py index 9c79f2f5..8921ecfa 100644 --- a/tests/trainers/test_text_generation_trainer.py +++ b/tests/trainers/test_text_generation_trainer.py @@ -50,17 +50,11 @@ class TestTextGenerationTrainer(unittest.TestCase): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_trainer(self): - def cfg_modify_fn(cfg): - cfg.preprocessor.type = 'text-gen-tokenizer' - return cfg - kwargs = dict( model=self.model_id, train_dataset=self.dataset, eval_dataset=self.dataset, - work_dir=self.tmp_dir, - cfg_modify_fn=cfg_modify_fn, - model_revision='beta') + work_dir=self.tmp_dir) trainer = build_trainer( name='NlpEpochBasedTrainer', default_args=kwargs) @@ -76,7 +70,7 @@ class TestTextGenerationTrainer(unittest.TestCase): if not os.path.exists(tmp_dir): os.makedirs(tmp_dir) - cache_path = snapshot_download(self.model_id, revision='beta') + cache_path = snapshot_download(self.model_id) model = PalmForTextGeneration.from_pretrained(cache_path) kwargs = dict( cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),