|
|
|
@@ -52,7 +52,7 @@ if __name__ == '__main__': |
|
|
|
#获取数据集路径 |
|
|
|
MnistDataset_torch_path = c2net_context.dataset_path+"/"+"MnistDataset_torch" |
|
|
|
#获取预训练模型路径 |
|
|
|
Torch_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"Torch_MNIST_Example_Model" |
|
|
|
Torch_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"GCU_MNIST_Example_Model" |
|
|
|
#获取输出路径 |
|
|
|
output_path = c2net_context.output_path |
|
|
|
# load DPU envs-xx.sh |
|
|
|
@@ -88,8 +88,8 @@ if __name__ == '__main__': |
|
|
|
print('epoch_size is:{}'.format(epochs)) |
|
|
|
|
|
|
|
# 如果有保存的模型,则加载模型,并在其基础上继续训练 |
|
|
|
if os.path.exists(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1.pkl")): |
|
|
|
checkpoint = torch.load(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1.pkl")) |
|
|
|
if os.path.exists(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1_0.81.pkl")): |
|
|
|
checkpoint = torch.load(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1_0.81.pkl")) |
|
|
|
model.load_state_dict(checkpoint['model']) |
|
|
|
optimizer.load_state_dict(checkpoint['optimizer']) |
|
|
|
start_epoch = checkpoint['epoch'] |
|
|
|
|