Browse Source

更新 'gcu_mnist_example/train.py'

pull/29/head
liuzx 1 year ago
parent
commit
02957cb8bf
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      gcu_mnist_example/train.py

+ 3
- 3
gcu_mnist_example/train.py View File

@@ -52,7 +52,7 @@ if __name__ == '__main__':
#获取数据集路径
MnistDataset_torch_path = c2net_context.dataset_path+"/"+"MnistDataset_torch"
#获取预训练模型路径
Torch_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"Torch_MNIST_Example_Model"
Torch_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"GCU_MNIST_Example_Model"
#获取输出路径
output_path = c2net_context.output_path
# load DPU envs-xx.sh
@@ -88,8 +88,8 @@ if __name__ == '__main__':
print('epoch_size is:{}'.format(epochs))

# 如果有保存的模型,则加载模型,并在其基础上继续训练
if os.path.exists(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1.pkl")):
checkpoint = torch.load(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1.pkl"))
if os.path.exists(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1_0.81.pkl")):
checkpoint = torch.load(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1_0.81.pkl"))
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']


Loading…
Cancel
Save