| @@ -52,7 +52,7 @@ if __name__ == '__main__': | |||||
| #获取数据集路径 | #获取数据集路径 | ||||
| MnistDataset_torch_path = c2net_context.dataset_path+"/"+"MnistDataset_torch" | MnistDataset_torch_path = c2net_context.dataset_path+"/"+"MnistDataset_torch" | ||||
| #获取预训练模型路径 | #获取预训练模型路径 | ||||
| Torch_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"Torch_MNIST_Example_Model" | |||||
| Torch_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"GCU_MNIST_Example_Model" | |||||
| #获取输出路径 | #获取输出路径 | ||||
| output_path = c2net_context.output_path | output_path = c2net_context.output_path | ||||
| # load DPU envs-xx.sh | # load DPU envs-xx.sh | ||||
| @@ -88,8 +88,8 @@ if __name__ == '__main__': | |||||
| print('epoch_size is:{}'.format(epochs)) | print('epoch_size is:{}'.format(epochs)) | ||||
| # 如果有保存的模型,则加载模型,并在其基础上继续训练 | # 如果有保存的模型,则加载模型,并在其基础上继续训练 | ||||
| if os.path.exists(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1.pkl")): | |||||
| checkpoint = torch.load(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1.pkl")) | |||||
| if os.path.exists(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1_0.81.pkl")): | |||||
| checkpoint = torch.load(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1_0.81.pkl")) | |||||
| model.load_state_dict(checkpoint['model']) | model.load_state_dict(checkpoint['model']) | ||||
| optimizer.load_state_dict(checkpoint['optimizer']) | optimizer.load_state_dict(checkpoint['optimizer']) | ||||
| start_epoch = checkpoint['epoch'] | start_epoch = checkpoint['epoch'] | ||||