|
|
|
@@ -30,8 +30,8 @@ from datasets.mnist_add.get_mnist_add import get_mnist_add |
|
|
|
from datasets.hwf.get_hwf import get_hwf |
|
|
|
|
|
|
|
class Params: |
|
|
|
imgH = 45 |
|
|
|
imgW = 45 |
|
|
|
imgH = 28 |
|
|
|
imgW = 28 |
|
|
|
keep_ratio = True |
|
|
|
saveInterval = 10 |
|
|
|
batchSize = 16 |
|
|
|
@@ -45,17 +45,15 @@ def run_test(): |
|
|
|
|
|
|
|
recorder_file_path = f"{result_dir}/1116.pk"# |
|
|
|
|
|
|
|
# kb = add_KB() |
|
|
|
kb = hwf_KB() |
|
|
|
abducer = AbducerBase(kb) |
|
|
|
kb = add_KB() |
|
|
|
# kb = hwf_KB() |
|
|
|
abducer = AbducerBase(kb, 2) |
|
|
|
|
|
|
|
recorder = logger() |
|
|
|
recorder.set_savefile("test.log") |
|
|
|
|
|
|
|
|
|
|
|
# train_X, train_Y, test_X, test_Y = get_mnist_add() |
|
|
|
train_X, train_Y, test_X, test_Y = get_hwf() |
|
|
|
|
|
|
|
train_X, train_Z, train_Y = get_mnist_add(train = True, get_pseudo_label = True) |
|
|
|
test_X, test_Z, test_Y = get_mnist_add(train = False, get_pseudo_label = True) |
|
|
|
|
|
|
|
recorder = plog.ResultRecorder() |
|
|
|
cls = LeNet5(num_classes=len(kb.pseudo_label_list), image_size=(train_X[0][0].shape[1:])) |
|
|
|
@@ -67,7 +65,7 @@ def run_test(): |
|
|
|
base_model = BasicModel(cls, criterion, optimizer, device, Params(), recorder=recorder) |
|
|
|
model = MyModel(base_model, kb.pseudo_label_list) |
|
|
|
|
|
|
|
res = framework.train(model, abducer, train_X, train_Y, sample_num = 10000, verbose = 1) |
|
|
|
res = framework.train(model, abducer, train_X, train_Z, train_Y, sample_num = 10000, verbose = 1) |
|
|
|
print(res) |
|
|
|
|
|
|
|
|
|
|
|
|