|
|
|
@@ -9,7 +9,7 @@ def get_data(file, img_dataset): |
|
|
|
with open(file) as f: |
|
|
|
for line in f: |
|
|
|
line = line.strip().split(' ') |
|
|
|
X.append((img_dataset[int(line[0])][0], img_dataset[int(line[1])][0])) |
|
|
|
X.append([img_dataset[int(line[0])][0], img_dataset[int(line[1])][0]]) |
|
|
|
Y.append(int(line[2])) |
|
|
|
return X, Y |
|
|
|
|
|
|
|
@@ -26,3 +26,4 @@ if __name__ == "__main__": |
|
|
|
train_X, train_Y, test_X, test_Y = get_mnist_add() |
|
|
|
print(len(train_X), len(test_X)) |
|
|
|
print(train_X[0][0].shape, train_X[0][1].shape, train_Y[0]) |
|
|
|
|