diff --git a/datasets/mnist_add/get_mnist_add.py b/datasets/mnist_add/get_mnist_add.py index 183f9a5..eb75e47 100644 --- a/datasets/mnist_add/get_mnist_add.py +++ b/datasets/mnist_add/get_mnist_add.py @@ -9,7 +9,7 @@ def get_data(file, img_dataset): with open(file) as f: for line in f: line = line.strip().split(' ') - X.append((img_dataset[int(line[0])][0], img_dataset[int(line[1])][0])) + X.append([img_dataset[int(line[0])][0], img_dataset[int(line[1])][0]]) Y.append(int(line[2])) return X, Y @@ -26,3 +26,4 @@ if __name__ == "__main__": train_X, train_Y, test_X, test_Y = get_mnist_add() print(len(train_X), len(test_X)) print(train_X[0][0].shape, train_X[0][1].shape, train_Y[0]) +