|
|
|
@@ -53,6 +53,7 @@ def test(model, test_loader, data_length): |
|
|
|
test_loss /= (i+1) |
|
|
|
|
|
|
|
# 结果写入输出文件夹 |
|
|
|
print('accuracy: {:.2f}'.format(correct / data_length)) |
|
|
|
filename = 'result.txt' |
|
|
|
file_path = os.path.join('/tmp/output', filename) |
|
|
|
with open(file_path, 'w') as file: |
|
|
|
@@ -74,9 +75,9 @@ if __name__ == '__main__': |
|
|
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
|
|
batch_size = args.batch_size |
|
|
|
epochs = args.epoch_size |
|
|
|
test_dataset = mnist.MNIST(root=mnist_example_test2_model_djts_path + "/test", train=False, transform=ToTensor(),download=False) |
|
|
|
test_dataset = mnist.MNIST(root=MnistDataset_torch + "/test", train=False, transform=ToTensor(),download=False) |
|
|
|
test_loader = DataLoader(test_dataset, batch_size=batch_size) |
|
|
|
model = Model().to(device) |
|
|
|
checkpoint = torch.load(mnist_example_test2_model_djts_path + "/mnist_epoch1_0.73.pkl") |
|
|
|
checkpoint = torch.load(mnist_example_test2_model_djts_path + "/mnist_epoch1.pkl") |
|
|
|
model.load_state_dict(checkpoint['model']) |
|
|
|
test(model,test_loader,len(test_dataset)) |