From fb31009f12df68f900eaa77d560da8520cbb029c Mon Sep 17 00:00:00 2001 From: lihongkang <[lihongkang1@huawei.com]> Date: Sat, 6 Mar 2021 10:25:59 +0800 Subject: [PATCH] update code for unet 310 infer --- model_zoo/official/cv/unet/postprocess.py | 5 +++-- model_zoo/official/cv/unet/preprocess.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/model_zoo/official/cv/unet/postprocess.py b/model_zoo/official/cv/unet/postprocess.py index 7c31081ef6..bc179bd39d 100644 --- a/model_zoo/official/cv/unet/postprocess.py +++ b/model_zoo/official/cv/unet/postprocess.py @@ -60,7 +60,8 @@ def test_net(data_dir, cross_valid_ind=1, cfg=None): - _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False) + _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'], + img_size=cfg['img_size']) labels_list = [] for data in valid_dataset: @@ -90,7 +91,7 @@ if __name__ == '__main__': for j in range(len(os.listdir(rst_path))): file_name = rst_path + "ISBI_test_bs_1_" + str(j) + "_0" + ".bin" - output = np.fromfile(file_name, np.float32).reshape(1, 2, 388, 388) + output = np.fromfile(file_name, np.float32).reshape(1, 2, 576, 576) label = label_list[j] metrics.update(output, label) diff --git a/model_zoo/official/cv/unet/preprocess.py b/model_zoo/official/cv/unet/preprocess.py index 4a96916cc7..9a87b5bd07 100644 --- a/model_zoo/official/cv/unet/preprocess.py +++ b/model_zoo/official/cv/unet/preprocess.py @@ -20,7 +20,8 @@ from src.config import cfg_unet def preprocess_dataset(data_dir, result_path, cross_valid_ind=1, cfg=None): - _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False) + _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'], + img_size=cfg['img_size']) for i, data in enumerate(valid_dataset): file_name = "ISBI_test_bs_1_" + str(i) + ".bin"