diff --git a/framework.py b/framework.py index ed013b7..2a1fdad 100644 --- a/framework.py +++ b/framework.py @@ -68,7 +68,7 @@ def train(model, abducer, train_data, test_data, epochs = 50, sample_num = -1, v # Set default parameters if sample_num == -1: - sample_num = len(X) + sample_num = len(train_X) if verbose < 1: verbose = epochs