You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

preprocess.py 4.2 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """unet 310 infer preprocess dataset"""
  16. import argparse
  17. import os
  18. import numpy as np
  19. import cv2
  20. from src.data_loader import create_dataset
  21. from src.config import cfg_unet
  def preprocess_dataset(data_dir, result_path, cross_valid_ind=1, cfg=None):
      """
      Dump the validation split to disk for offline (310) inference.

      Every sample yielded by the validation dataset is split in two: element 0
      is written as a raw binary file ``ISBI_test_bs_1_<i>.bin`` inside
      ``result_path``; element 1 is collected and all collected arrays are saved
      together to ``./label.npy`` in the current working directory.

      Args:
          data_dir: dataset location, forwarded to ``create_dataset``.
          result_path: directory that receives the ``.bin`` files. It is created
              if it does not exist.
          cross_valid_ind: cross-validation index, forwarded to ``create_dataset``.
          cfg: mapping that must provide the ``'crop'`` and ``'img_size'`` entries.

      Raises:
          TypeError: if ``cfg`` is not supplied.
      """
      if cfg is None:
          # Fail with a readable message instead of "'NoneType' object is not subscriptable".
          raise TypeError("cfg must be a config mapping providing 'crop' and 'img_size'")

      # tofile() does not create missing directories, so make sure the target exists.
      if result_path:
          os.makedirs(result_path, exist_ok=True)

      # Only the second dataset returned by create_dataset is used here.
      # NOTE(review): the positional 1, 1, False, ..., False arguments are presumably
      # repeat / batch size / augment / run_distribute -- confirm against src.data_loader.
      _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'],
                                        img_size=cfg['img_size'])

      labels_list = []
      for i, data in enumerate(valid_dataset):
          file_name = "ISBI_test_bs_1_" + str(i) + ".bin"
          # os.path.join keeps the path correct whether or not result_path ends with a separator.
          file_path = os.path.join(result_path, file_name)
          data[0].asnumpy().tofile(file_path)
          labels_list.append(data[1].asnumpy())

      np.save("./label.npy", labels_list)
  class CellNucleiDataset:
      """
      Cell nuclei dataset preprocess class.

      The sub-directory names found directly under ``data_dir`` are the sample
      ids. They are sorted and split into a training part (the leading ``split``
      fraction, repeated ``repeat`` times and shuffled) and a validation part
      (the remainder). On construction, the image of every validation sample is
      copied to ``result_path`` as ``<img_id>.png``.
      """

      def __init__(self, data_dir, repeat, result_path, is_train=False, split=0.8):
          """
          Index the dataset and immediately export the validation images.

          Args:
              data_dir: directory whose immediate sub-directories are the sample ids.
              repeat: how many times the training id list is repeated.
              result_path: directory that receives the exported validation images.
              is_train: selects which id list ``__getitem__`` / ``__len__`` use.
              split: fraction of the sorted ids assigned to the training list.
          """
          self.data_dir = data_dir
          # next(os.walk(...))[1] is the list of immediate sub-directories.
          self.img_ids = sorted(next(os.walk(self.data_dir))[1])
          split_at = int(len(self.img_ids) * split)
          self.train_ids = self.img_ids[:split_at] * repeat
          np.random.shuffle(self.train_ids)
          self.val_ids = self.img_ids[split_at:]
          self.is_train = is_train
          self.result_path = result_path
          self._preprocess_dataset()

      def _preprocess_dataset(self):
          """
          Write ``<data_dir>/<id>/images/<id>.png`` of every validation id to
          ``<result_path>/<id>.png``.

          Raises:
              OSError: if an input image cannot be read.
          """
          # cv2.imwrite fails silently when the target directory is missing.
          if self.result_path:
              os.makedirs(self.result_path, exist_ok=True)

          for img_id in self.val_ids:
              img_path = os.path.join(self.data_dir, img_id, "images", img_id + ".png")
              img = cv2.imread(img_path)
              if img is None:
                  # cv2.imread signals failure by returning None instead of raising.
                  raise OSError("failed to read image: " + img_path)
              if len(img.shape) == 2:
                  # Replicate a single-channel image into three identical channels.
                  img = np.expand_dims(img, axis=-1)
                  img = np.concatenate([img, img, img], axis=-1)
              cv2.imwrite(os.path.join(self.result_path, img_id + ".png"), img)

      def _read_img_mask(self, img_id):
          """
          Load ``image.png`` and the grayscale ``mask.png`` of one sample.

          NOTE(review): these files are expected directly inside the sample
          directory, which differs from the ``images/<id>.png`` layout read by
          ``_preprocess_dataset`` -- presumably produced by a separate
          preparation step; confirm.
          """
          path = os.path.join(self.data_dir, img_id)
          img = cv2.imread(os.path.join(path, "image.png"))
          mask = cv2.imread(os.path.join(path, "mask.png"), cv2.IMREAD_GRAYSCALE)
          return img, mask

      def __getitem__(self, index):
          """Return the ``(image, mask)`` pair at ``index`` of the active id list."""
          if self.is_train:
              return self._read_img_mask(self.train_ids[index])
          return self._read_img_mask(self.val_ids[index])

      @property
      def column_names(self):
          """Names of the two columns returned by ``__getitem__``."""
          column_names = ['image', 'mask']
          return column_names

      def __len__(self):
          """Number of ids in the active (training or validation) list."""
          if self.is_train:
              return len(self.train_ids)
          return len(self.val_ids)
  def get_args():
      """
      Parse the command-line options of the preprocessing script.

      Returns:
          argparse.Namespace with ``data_url`` (input data directory) and
          ``result_path`` (output directory).
      """
      cli = argparse.ArgumentParser(
          formatter_class=argparse.ArgumentDefaultsHelpFormatter,
          description='Preprocess the UNet dataset ',
      )
      # (flags, destination, default value, help text) for every string option.
      options = (
          (('-d', '--data_url'), 'data_url', 'data/', 'data directory'),
          (('-p', '--result_path'), 'result_path', './preprocess_Result/', 'result path'),
      )
      for flags, dest, default, help_text in options:
          cli.add_argument(*flags, dest=dest, type=str, default=default, help=help_text)
      return cli.parse_args()
  if __name__ == '__main__':
      args = get_args()
      # The configured dataset decides which preprocessing path runs: anything other
      # than "Cell_nuclei" (including a config without a 'dataset' entry) goes through
      # preprocess_dataset.
      if 'dataset' not in cfg_unet or cfg_unet['dataset'] != "Cell_nuclei":
          preprocess_dataset(data_dir=args.data_url, result_path=args.result_path,
                             cross_valid_ind=cfg_unet['cross_valid_ind'], cfg=cfg_unet)
      else:
          # Constructing the dataset object performs the export as a side effect.
          cell_dataset = CellNucleiDataset(data_dir=args.data_url, repeat=1, result_path=args.result_path,
                                           is_train=False, split=0.8)