|
|
|
@@ -15,10 +15,10 @@ def get_data(file, img_dataset): |
|
|
|
|
|
|
|
def get_mnist_add(): |
|
|
|
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))]) |
|
|
|
img_dataset = torchvision.datasets.MNIST(root='./', train=True, download=True, transform=transform) |
|
|
|
img_dataset = torchvision.datasets.MNIST(root='./datasets/mnist_add/', train=True, download=True, transform=transform) |
|
|
|
|
|
|
|
train_X, train_Y = get_data('./train_data.txt', img_dataset) |
|
|
|
test_X, test_Y = get_data('./test_data.txt', img_dataset) |
|
|
|
train_X, train_Y = get_data('./datasets/mnist_add/train_data.txt', img_dataset) |
|
|
|
test_X, test_Y = get_data('./datasets/mnist_add/test_data.txt', img_dataset) |
|
|
|
|
|
|
|
return train_X, train_Y, test_X, test_Y |
|
|
|
|
|
|
|
|