You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 6.9 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import math
  16. import numpy as np
  17. from src.config import config
  18. def correct_nifti_head(img):
  19. """
  20. Check nifti object header's format, update the header if needed.
  21. In the updated image pixdim matches the affine.
  22. Args:
  23. img: nifti image object
  24. """
  25. dim = img.header["dim"][0]
  26. if dim >= 5:
  27. return img
  28. pixdim = np.asarray(img.header.get_zooms())[:dim]
  29. norm_affine = np.sqrt(np.sum(np.square(img.affine[:dim, :dim]), 0))
  30. if np.allclose(pixdim, norm_affine):
  31. return img
  32. if hasattr(img, "get_sform"):
  33. return rectify_header_sform_qform(img)
  34. return img
  35. def get_random_patch(dims, patch_size, rand_fn=None):
  36. """
  37. Returns a tuple of slices to define a random patch in an array of shape `dims` with size `patch_size`.
  38. Args:
  39. dims: shape of source array
  40. patch_size: shape of patch size to generate
  41. rand_fn: generate random numbers
  42. Returns:
  43. (tuple of slice): a tuple of slice objects defining the patch
  44. """
  45. rand_int = np.random.randint if rand_fn is None else rand_fn.randint
  46. min_corner = tuple(rand_int(0, ms - ps + 1) if ms > ps else 0 for ms, ps in zip(dims, patch_size))
  47. return tuple(slice(mc, mc + ps) for mc, ps in zip(min_corner, patch_size))
  48. def first(iterable, default=None):
  49. """
  50. Returns the first item in the given iterable or `default` if empty, meaningful mostly with 'for' expressions.
  51. """
  52. for i in iterable:
  53. return i
  54. return default
  55. def _get_scan_interval(image_size, roi_size, num_image_dims, overlap):
  56. """
  57. Compute scan interval according to the image size, roi size and overlap.
  58. Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0,
  59. use 1 instead to make sure sliding window works.
  60. """
  61. if len(image_size) != num_image_dims:
  62. raise ValueError("image different from spatial dims.")
  63. if len(roi_size) != num_image_dims:
  64. raise ValueError("roi size different from spatial dims.")
  65. scan_interval = []
  66. for i in range(num_image_dims):
  67. if roi_size[i] == image_size[i]:
  68. scan_interval.append(int(roi_size[i]))
  69. else:
  70. interval = int(roi_size[i] * (1 - overlap))
  71. scan_interval.append(interval if interval > 0 else 1)
  72. return tuple(scan_interval)
  73. def dense_patch_slices(image_size, patch_size, scan_interval):
  74. """
  75. Enumerate all slices defining ND patches of size `patch_size` from an `image_size` input image.
  76. Args:
  77. image_size: dimensions of image to iterate over
  78. patch_size: size of patches to generate slices
  79. scan_interval: dense patch sampling interval
  80. Returns:
  81. a list of slice objects defining each patch
  82. """
  83. num_spatial_dims = len(image_size)
  84. patch_size = patch_size
  85. scan_num = []
  86. for i in range(num_spatial_dims):
  87. if scan_interval[i] == 0:
  88. scan_num.append(1)
  89. else:
  90. num = int(math.ceil(float(image_size[i]) / scan_interval[i]))
  91. scan_dim = first(d for d in range(num) if d * scan_interval[i] + patch_size[i] >= image_size[i])
  92. scan_num.append(scan_dim + 1 if scan_dim is not None else 1)
  93. starts = []
  94. for dim in range(num_spatial_dims):
  95. dim_starts = []
  96. for idx in range(scan_num[dim]):
  97. start_idx = idx * scan_interval[dim]
  98. start_idx -= max(start_idx + patch_size[dim] - image_size[dim], 0)
  99. dim_starts.append(start_idx)
  100. starts.append(dim_starts)
  101. out = np.asarray([x.flatten() for x in np.meshgrid(*starts, indexing="ij")]).T
  102. return [(slice(None),)*2 + tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out]
  103. def create_sliding_window(image, roi_size, overlap):
  104. num_image_dims = len(image.shape) - 2
  105. if overlap < 0 or overlap >= 1:
  106. raise AssertionError("overlap must be >= 0 and < 1.")
  107. image_size_temp = list(image.shape[2:])
  108. image_size = tuple(max(image_size_temp[i], roi_size[i]) for i in range(num_image_dims))
  109. scan_interval = _get_scan_interval(image_size, roi_size, num_image_dims, overlap)
  110. slices = dense_patch_slices(image_size, roi_size, scan_interval)
  111. windows_sliding = [image[slice] for slice in slices]
  112. return windows_sliding, slices
  113. def one_hot(labels):
  114. N, _, D, H, W = labels.shape
  115. labels = np.reshape(labels, (N, -1))
  116. labels = labels.astype(np.int32)
  117. N, K = labels.shape
  118. one_hot_encoding = np.zeros((N, config['num_classes'], K), dtype=np.float32)
  119. for i in range(N):
  120. for j in range(K):
  121. one_hot_encoding[i, labels[i][j], j] = 1
  122. labels = np.reshape(one_hot_encoding, (N, config['num_classes'], D, H, W))
  123. return labels
  124. def CalculateDice(y_pred, label):
  125. """
  126. Args:
  127. y_pred: predictions. As for classification tasks,
  128. `y_pred` should has the shape [BN] where N is larger than 1. As for segmentation tasks,
  129. the shape should be [BNHW] or [BNHWD].
  130. label: ground truth, the first dim is batch.
  131. """
  132. y_pred_output = np.expand_dims(np.argmax(y_pred, axis=1), axis=1)
  133. y_pred = one_hot(y_pred_output)
  134. y = one_hot(label)
  135. y_pred, y = ignore_background(y_pred, y)
  136. inter = np.dot(y_pred.flatten(), y.flatten()).astype(np.float64)
  137. union = np.dot(y_pred.flatten(), y_pred.flatten()).astype(np.float64) + np.dot(y.flatten(), \
  138. y.flatten()).astype(np.float64)
  139. single_dice_coeff = 2 * inter / (union + 1e-6)
  140. return single_dice_coeff, y_pred_output
  141. def ignore_background(y_pred, label):
  142. """
  143. This function is used to remove background (the first channel) for `y_pred` and `y`.
  144. Args:
  145. y_pred: predictions. As for classification tasks,
  146. `y_pred` should has the shape [BN] where N is larger than 1. As for segmentation tasks,
  147. the shape should be [BNHW] or [BNHWD].
  148. label: ground truth, the first dim is batch.
  149. """
  150. label = label[:, 1:] if label.shape[1] > 1 else label
  151. y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred
  152. return y_pred, label