diff --git a/datasets/mnist_add/get_mnist_add.py b/datasets/mnist_add/get_mnist_add.py index eb75e47..fcada50 100644 --- a/datasets/mnist_add/get_mnist_add.py +++ b/datasets/mnist_add/get_mnist_add.py @@ -15,10 +15,10 @@ def get_data(file, img_dataset): def get_mnist_add(): transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))]) - img_dataset = torchvision.datasets.MNIST(root='./', train=True, download=True, transform=transform) + img_dataset = torchvision.datasets.MNIST(root='./datasets/mnist_add/', train=True, download=True, transform=transform) - train_X, train_Y = get_data('./train_data.txt', img_dataset) - test_X, test_Y = get_data('./test_data.txt', img_dataset) + train_X, train_Y = get_data('./datasets/mnist_add/train_data.txt', img_dataset) + test_X, test_Y = get_data('./datasets/mnist_add/test_data.txt', img_dataset) return train_X, train_Y, test_X, test_Y