diff --git a/gpu_mnist_example/train.py b/gpu_mnist_example/train.py index 9db7665..e71eaf2 100644 --- a/gpu_mnist_example/train.py +++ b/gpu_mnist_example/train.py @@ -109,7 +109,7 @@ if __name__ == '__main__': start_epoch = 0 print('无保存模型,将从头开始训练!') - for epoch in range(start_epoch+1, epochs): + for epoch in range(start_epoch+1, epochs+1): train(model, train_loader, epoch) test(model, test_loader, test_dataset) # 将模型保存到c2net_context.output_path