Browse Source

update code for unet 310 infer

tags/v1.2.0-rc1
lihongkang 4 years ago
parent
commit
fb31009f12
2 changed files with 5 additions and 3 deletions
  1. +3
    -2
      model_zoo/official/cv/unet/postprocess.py
  2. +2
    -1
      model_zoo/official/cv/unet/preprocess.py

+ 3
- 2
model_zoo/official/cv/unet/postprocess.py View File

@@ -60,7 +60,8 @@ def test_net(data_dir,
cross_valid_ind=1,
cfg=None):

_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False)
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'],
img_size=cfg['img_size'])
labels_list = []

for data in valid_dataset:
@@ -90,7 +91,7 @@ if __name__ == '__main__':

for j in range(len(os.listdir(rst_path))):
file_name = rst_path + "ISBI_test_bs_1_" + str(j) + "_0" + ".bin"
output = np.fromfile(file_name, np.float32).reshape(1, 2, 388, 388)
output = np.fromfile(file_name, np.float32).reshape(1, 2, 576, 576)
label = label_list[j]
metrics.update(output, label)



+ 2
- 1
model_zoo/official/cv/unet/preprocess.py View File

@@ -20,7 +20,8 @@ from src.config import cfg_unet

def preprocess_dataset(data_dir, result_path, cross_valid_ind=1, cfg=None):

_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False)
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'],
img_size=cfg['img_size'])

for i, data in enumerate(valid_dataset):
file_name = "ISBI_test_bs_1_" + str(i) + ".bin"


Loading…
Cancel
Save