|
|
@@ -60,7 +60,8 @@ def test_net(data_dir, |
|
|
cross_valid_ind=1, |
|
|
cross_valid_ind=1, |
|
|
cfg=None): |
|
|
cfg=None): |
|
|
|
|
|
|
|
|
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False) |
|
|
|
|
|
|
|
|
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'], |
|
|
|
|
|
img_size=cfg['img_size']) |
|
|
labels_list = [] |
|
|
labels_list = [] |
|
|
|
|
|
|
|
|
for data in valid_dataset: |
|
|
for data in valid_dataset: |
|
|
@@ -90,7 +91,7 @@ if __name__ == '__main__': |
|
|
|
|
|
|
|
|
for j in range(len(os.listdir(rst_path))): |
|
|
for j in range(len(os.listdir(rst_path))): |
|
|
file_name = rst_path + "ISBI_test_bs_1_" + str(j) + "_0" + ".bin" |
|
|
file_name = rst_path + "ISBI_test_bs_1_" + str(j) + "_0" + ".bin" |
|
|
output = np.fromfile(file_name, np.float32).reshape(1, 2, 388, 388) |
|
|
|
|
|
|
|
|
output = np.fromfile(file_name, np.float32).reshape(1, 2, 576, 576) |
|
|
label = label_list[j] |
|
|
label = label_list[j] |
|
|
metrics.update(output, label) |
|
|
metrics.update(output, label) |
|
|
|
|
|
|
|
|
|