|
|
|
@@ -18,7 +18,6 @@ import argparse |
|
|
|
import cv2 |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from src.data_loader import create_dataset, create_cell_nuclei_dataset |
|
|
|
from src.config import cfg_unet |
|
|
|
|
|
|
|
class dice_coeff(): |
|
|
|
@@ -74,25 +73,6 @@ class dice_coeff(): |
|
|
|
raise RuntimeError('Total samples num must not be 0.') |
|
|
|
return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num)) |
|
|
|
|
|
|
|
|
|
|
|
def test_net(data_dir, |
|
|
|
cross_valid_ind=1, |
|
|
|
cfg=None): |
|
|
|
|
|
|
|
if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei": |
|
|
|
valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False, |
|
|
|
eval_resize=cfg["eval_resize"], split=0.8) |
|
|
|
else: |
|
|
|
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'], |
|
|
|
img_size=cfg['img_size']) |
|
|
|
labels_list = [] |
|
|
|
|
|
|
|
for data in valid_dataset: |
|
|
|
labels_list.append(data[1].asnumpy()) |
|
|
|
|
|
|
|
return labels_list |
|
|
|
|
|
|
|
|
|
|
|
def get_args(): |
|
|
|
parser = argparse.ArgumentParser(description='Test the UNet on images and target masks', |
|
|
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter) |
|
|
|
@@ -105,24 +85,31 @@ def get_args(): |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
|
|
|
args = get_args() |
|
|
|
|
|
|
|
label_list = test_net(data_dir=args.data_url, cross_valid_ind=cfg_unet['cross_valid_ind'], cfg=cfg_unet) |
|
|
|
rst_path = args.rst_path |
|
|
|
metrics = dice_coeff() |
|
|
|
|
|
|
|
if 'dataset' in cfg_unet and cfg_unet['dataset'] == "Cell_nuclei": |
|
|
|
img_size = tuple(cfg_unet['img_size']) |
|
|
|
for i, bin_name in enumerate(os.listdir('./preprocess_Result/')): |
|
|
|
bin_name_softmax = bin_name.replace(".png", "") + "_0.bin" |
|
|
|
bin_name_argmax = bin_name.replace(".png", "") + "_1.bin" |
|
|
|
f = bin_name.replace(".png", "") |
|
|
|
bin_name_softmax = f + "_0.bin" |
|
|
|
bin_name_argmax = f + "_1.bin" |
|
|
|
file_name_sof = rst_path + bin_name_softmax |
|
|
|
file_name_arg = rst_path + bin_name_argmax |
|
|
|
softmax_out = np.fromfile(file_name_sof, np.float32).reshape(1, 96, 96, 2) |
|
|
|
argmax_out = np.fromfile(file_name_arg, np.float32).reshape(1, 96, 96) |
|
|
|
label = label_list[i] |
|
|
|
mask = cv2.imread(os.path.join(args.data_url, f, "mask.png"), cv2.IMREAD_GRAYSCALE) |
|
|
|
mask = cv2.resize(mask, img_size) |
|
|
|
mask = mask.astype(np.float32) / 255 |
|
|
|
mask = (mask > 0.5).astype(np.int) |
|
|
|
mask = (np.arange(2) == mask[..., None]).astype(int) |
|
|
|
mask = mask.transpose(2, 0, 1).astype(np.float32) |
|
|
|
label = mask.reshape(1, 2, 96, 96) |
|
|
|
metrics.update((softmax_out, argmax_out), label) |
|
|
|
else: |
|
|
|
label_list = np.load('label.npy') |
|
|
|
for j in range(len(os.listdir('./preprocess_Result/'))): |
|
|
|
file_name_sof = rst_path + "ISBI_test_bs_1_" + str(j) + "_0" + ".bin" |
|
|
|
file_name_arg = rst_path + "ISBI_test_bs_1_" + str(j) + "_1" + ".bin" |
|
|
|
|