You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

eval.py 3.7 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import os
  16. import argparse
  17. import numpy as np
  18. from mindspore import dtype as mstype
  19. from mindspore import Model, context, Tensor
  20. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  21. from src.dataset import create_dataset
  22. from src.unet3d_model import UNet3d
  23. from src.config import config as cfg
  24. from src.utils import create_sliding_window, CalculateDice
  25. device_id = int(os.getenv('DEVICE_ID'))
  26. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
  27. def get_args():
  28. parser = argparse.ArgumentParser(description='Test the UNet3D on images and target masks')
  29. parser.add_argument('--data_url', dest='data_url', type=str, default='', help='image data directory')
  30. parser.add_argument('--seg_url', dest='seg_url', type=str, default='', help='seg data directory')
  31. parser.add_argument('--ckpt_path', dest='ckpt_path', type=str, default='', help='checkpoint path')
  32. return parser.parse_args()
  33. def test_net(data_dir, seg_dir, ckpt_path, config=None):
  34. eval_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, config=config, is_training=False)
  35. eval_data_size = eval_dataset.get_dataset_size()
  36. print("train dataset length is:", eval_data_size)
  37. network = UNet3d(config=config)
  38. network.set_train(False)
  39. param_dict = load_checkpoint(ckpt_path)
  40. load_param_into_net(network, param_dict)
  41. model = Model(network)
  42. index = 0
  43. total_dice = 0
  44. for batch in eval_dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  45. image = batch["image"]
  46. seg = batch["seg"]
  47. print("current image shape is {}".format(image.shape), flush=True)
  48. sliding_window_list, slice_list = create_sliding_window(image, config.roi_size, config.overlap)
  49. image_size = (config.batch_size, config.num_classes) + image.shape[2:]
  50. output_image = np.zeros(image_size, np.float32)
  51. count_map = np.zeros(image_size, np.float32)
  52. importance_map = np.ones(config.roi_size, np.float32)
  53. for window, slice_ in zip(sliding_window_list, slice_list):
  54. window_image = Tensor(window, mstype.float32)
  55. pred_probs = model.predict(window_image)
  56. output_image[slice_] += pred_probs.asnumpy()
  57. count_map[slice_] += importance_map
  58. output_image = output_image / count_map
  59. dice, _ = CalculateDice(output_image, seg)
  60. print("The {} batch dice is {}".format(index, dice), flush=True)
  61. total_dice += dice
  62. index = index + 1
  63. avg_dice = total_dice / eval_data_size
  64. print("**********************End Eval***************************************")
  65. print("eval average dice is {}".format(avg_dice))
  66. if __name__ == '__main__':
  67. args = get_args()
  68. print("Testing setting:", args)
  69. test_net(data_dir=args.data_url,
  70. seg_dir=args.seg_url,
  71. ckpt_path=args.ckpt_path,
  72. config=cfg)