From f9812e319958b99ff2185fe88f4797119bac7ccf Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Wed, 16 Nov 2022 13:22:54 +0800 Subject: [PATCH] Update get_mnist_add.py --- datasets/mnist_add/get_mnist_add.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datasets/mnist_add/get_mnist_add.py b/datasets/mnist_add/get_mnist_add.py index 183f9a5..eb75e47 100644 --- a/datasets/mnist_add/get_mnist_add.py +++ b/datasets/mnist_add/get_mnist_add.py @@ -9,7 +9,7 @@ def get_data(file, img_dataset): with open(file) as f: for line in f: line = line.strip().split(' ') - X.append((img_dataset[int(line[0])][0], img_dataset[int(line[1])][0])) + X.append([img_dataset[int(line[0])][0], img_dataset[int(line[1])][0]]) Y.append(int(line[2])) return X, Y @@ -26,3 +26,4 @@ if __name__ == "__main__": train_X, train_Y, test_X, test_Y = get_mnist_add() print(len(train_X), len(test_X)) print(train_X[0][0].shape, train_X[0][1].shape, train_Y[0]) +