diff --git a/model_zoo/official/recommend/deepfm/export.py b/model_zoo/official/recommend/deepfm/export.py index 87c38aafc9..230bafa75b 100644 --- a/model_zoo/official/recommend/deepfm/export.py +++ b/model_zoo/official/recommend/deepfm/export.py @@ -39,6 +39,7 @@ if __name__ == "__main__": model_builder = ModelBuilder(ModelConfig, TrainConfig) _, network = model_builder.get_train_eval_net() + network.set_train(False) load_checkpoint(args.ckpt_file, net=network)