From a8f33df3a306a4dbd98b6a290d9f23e3b75e28de Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Fri, 18 Nov 2022 16:03:54 +0800 Subject: [PATCH] Update example.py --- example.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/example.py b/example.py index 65789e2..14d2108 100644 --- a/example.py +++ b/example.py @@ -30,8 +30,8 @@ from datasets.mnist_add.get_mnist_add import get_mnist_add from datasets.hwf.get_hwf import get_hwf class Params: - imgH = 45 - imgW = 45 + imgH = 28 + imgW = 28 keep_ratio = True saveInterval = 10 batchSize = 16 @@ -45,17 +45,15 @@ def run_test(): recorder_file_path = f"{result_dir}/1116.pk"# - # kb = add_KB() - kb = hwf_KB() - abducer = AbducerBase(kb) + kb = add_KB() + # kb = hwf_KB() + abducer = AbducerBase(kb, 2) recorder = logger() recorder.set_savefile("test.log") - - # train_X, train_Y, test_X, test_Y = get_mnist_add() - train_X, train_Y, test_X, test_Y = get_hwf() - + train_X, train_Z, train_Y = get_mnist_add(train = True, get_pseudo_label = True) + test_X, test_Z, test_Y = get_mnist_add(train = False, get_pseudo_label = True) recorder = plog.ResultRecorder() cls = LeNet5(num_classes=len(kb.pseudo_label_list), image_size=(train_X[0][0].shape[1:])) @@ -67,7 +65,7 @@ def run_test(): base_model = BasicModel(cls, criterion, optimizer, device, Params(), recorder=recorder) model = MyModel(base_model, kb.pseudo_label_list) - res = framework.train(model, abducer, train_X, train_Y, sample_num = 10000, verbose = 1) + res = framework.train(model, abducer, train_X, train_Z, train_Y, sample_num = 10000, verbose = 1) print(res)