Browse Source

Update example.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
a8f33df3a3
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 10 deletions
  1. +8
    -10
      example.py

+ 8
- 10
example.py View File

@@ -30,8 +30,8 @@ from datasets.mnist_add.get_mnist_add import get_mnist_add
from datasets.hwf.get_hwf import get_hwf

class Params:
imgH = 45
imgW = 45
imgH = 28
imgW = 28
keep_ratio = True
saveInterval = 10
batchSize = 16
@@ -45,17 +45,15 @@ def run_test():

recorder_file_path = f"{result_dir}/1116.pk"#

# kb = add_KB()
kb = hwf_KB()
abducer = AbducerBase(kb)
kb = add_KB()
# kb = hwf_KB()
abducer = AbducerBase(kb, 2)

recorder = logger()
recorder.set_savefile("test.log")


# train_X, train_Y, test_X, test_Y = get_mnist_add()
train_X, train_Y, test_X, test_Y = get_hwf()

train_X, train_Z, train_Y = get_mnist_add(train = True, get_pseudo_label = True)
test_X, test_Z, test_Y = get_mnist_add(train = False, get_pseudo_label = True)

recorder = plog.ResultRecorder()
cls = LeNet5(num_classes=len(kb.pseudo_label_list), image_size=(train_X[0][0].shape[1:]))
@@ -67,7 +65,7 @@ def run_test():
base_model = BasicModel(cls, criterion, optimizer, device, Params(), recorder=recorder)
model = MyModel(base_model, kb.pseudo_label_list)

res = framework.train(model, abducer, train_X, train_Y, sample_num = 10000, verbose = 1)
res = framework.train(model, abducer, train_X, train_Z, train_Y, sample_num = 10000, verbose = 1)
print(res)



Loading…
Cancel
Save