From 02957cb8bf02a26440ca4b867072795a60a00e2e Mon Sep 17 00:00:00 2001 From: liuzx Date: Mon, 1 Jul 2024 11:12:41 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20'gcu=5Fmnist=5Fexample/tra?= =?UTF-8?q?in.py'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gcu_mnist_example/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gcu_mnist_example/train.py b/gcu_mnist_example/train.py index 44766a0..1aa91e6 100644 --- a/gcu_mnist_example/train.py +++ b/gcu_mnist_example/train.py @@ -52,7 +52,7 @@ if __name__ == '__main__': #获取数据集路径 MnistDataset_torch_path = c2net_context.dataset_path+"/"+"MnistDataset_torch" #获取预训练模型路径 - Torch_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"Torch_MNIST_Example_Model" + Torch_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"GCU_MNIST_Example_Model" #获取输出路径 output_path = c2net_context.output_path # load DPU envs-xx.sh @@ -88,8 +88,8 @@ if __name__ == '__main__': print('epoch_size is:{}'.format(epochs)) # 如果有保存的模型,则加载模型,并在其基础上继续训练 - if os.path.exists(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1.pkl")): - checkpoint = torch.load(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1.pkl")) + if os.path.exists(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1_0.81.pkl")): + checkpoint = torch.load(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1_0.81.pkl")) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) start_epoch = checkpoint['epoch']