|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- import math
- import numpy as np
- from src.config import config
-
def _affine_to_spacing(affine, r):
    """
    Return the voxel spacing encoded in an affine matrix.

    The spacing along each of the first ``r`` axes is the Euclidean norm of the
    corresponding column of the top-left ``r x r`` sub-matrix of ``affine``.
    """
    return np.sqrt(np.sum(np.square(np.asarray(affine)[:r, :r]), 0))


def _rectify_header_sform_qform(img):
    """
    Make the header pixdim of a nifti image consistent with its sform/qform.

    Resolution order:
      1. If the sform is coded (``sform_code != 0``), keep it when it matches
         pixdim; otherwise copy the qform into it when the qform matches.
      2. If the qform is coded, keep it when it matches pixdim; otherwise copy
         the sform into it when the sform matches.
      3. If neither matrix agrees with pixdim, overwrite the header zooms with
         the spacing derived from ``img.affine`` (a warning is emitted).

    Args:
        img: nifti image object exposing ``header``, ``affine``,
            ``get_sform``/``set_sform`` and ``get_qform``/``set_qform``.

    Returns:
        The same ``img`` object, updated in place when needed.
    """
    dim = img.header["dim"][0]
    pixdim = np.asarray(img.header.get_zooms())[:dim]
    sform_matches = np.allclose(_affine_to_spacing(img.get_sform(), dim), pixdim)
    qform_matches = np.allclose(_affine_to_spacing(img.get_qform(), dim), pixdim)

    if img.header["sform_code"] != 0:
        if sform_matches:
            return img
        if qform_matches:
            img.set_sform(img.get_qform())
            return img
    if img.header["qform_code"] != 0:
        if qform_matches:
            return img
        if sform_matches:
            img.set_qform(img.get_sform())
            return img

    # Neither stored matrix agrees with pixdim: trust the affine and rewrite the zooms.
    import warnings

    norm = _affine_to_spacing(img.affine, dim)
    warnings.warn("Modifying image pixdim from {} to {}".format(pixdim, norm))
    img.header.set_zooms(norm)
    return img


def correct_nifti_head(img):
    """
    Check nifti object header's format, update the header if needed.
    In the updated image pixdim matches the affine.

    Images with 5 or more dimensions are returned untouched. When the header
    zooms already equal the column norms of the affine, nothing changes.
    Otherwise, if the object supports sform/qform access (``get_sform``), the
    header is rectified in place.

    Args:
        img: nifti image object

    Returns:
        The (possibly updated) image object; the same object that was passed in.
    """
    dim = img.header["dim"][0]
    if dim >= 5:
        return img
    pixdim = np.asarray(img.header.get_zooms())[:dim]
    norm_affine = np.sqrt(np.sum(np.square(img.affine[:dim, :dim]), 0))
    if np.allclose(pixdim, norm_affine):
        return img
    if hasattr(img, "get_sform"):
        # The original code called an undefined `rectify_header_sform_qform`
        # (NameError); the rectification now lives in this module.
        return _rectify_header_sform_qform(img)
    return img
-
def get_random_patch(dims, patch_size, rand_fn=None):
    """
    Choose a random location for a patch of shape `patch_size` inside an array of shape `dims`.

    Args:
        dims: shape of the source array.
        patch_size: shape of the patch to cut out.
        rand_fn: optional random state providing ``randint``; ``np.random`` is
            used when it is omitted.

    Returns:
        (tuple of slice): one slice per axis selecting the patch. On axes where
        the patch is not smaller than the array, the patch starts at 0.
    """
    randint = np.random.randint if rand_fn is None else rand_fn.randint

    # First pass: draw a start offset for every axis (in axis order).
    starts = []
    for axis_len, patch_len in zip(dims, patch_size):
        if axis_len > patch_len:
            starts.append(randint(0, axis_len - patch_len + 1))
        else:
            starts.append(0)

    # Second pass: turn each offset into a slice of the requested length.
    return tuple(slice(start, start + patch_len) for start, patch_len in zip(starts, patch_size))
-
-
def first(iterable, default=None):
    """
    Return the first element produced by `iterable`, or `default` if it yields nothing.

    Mostly useful for picking the first match out of a generator expression.
    """
    return next(iter(iterable), default)
-
def _get_scan_interval(image_size, roi_size, num_image_dims, overlap):
    """
    Work out the sliding-window step for every spatial axis.

    An axis whose ROI spans the whole image steps by the ROI length. Every
    other axis steps by ``int(roi_size * (1 - overlap))``, raised to 1 when
    that would be zero or negative, so the window always advances.

    Raises:
        ValueError: if ``image_size`` or ``roi_size`` does not have
            ``num_image_dims`` entries.
    """
    for sizes, message in ((image_size, "image different from spatial dims."),
                           (roi_size, "roi size different from spatial dims.")):
        if len(sizes) != num_image_dims:
            raise ValueError(message)

    def _step(axis):
        roi_len = roi_size[axis]
        if roi_len == image_size[axis]:
            return int(roi_len)
        step = int(roi_len * (1 - overlap))
        return step if step > 0 else 1

    return tuple(_step(axis) for axis in range(num_image_dims))
-
def dense_patch_slices(image_size, patch_size, scan_interval):
    """
    Enumerate all slices defining ND patches of size `patch_size` from an `image_size` input image.

    Windows start every ``scan_interval[d]`` voxels along axis ``d``. The last
    window on each axis is shifted back so that it ends exactly at the image
    border instead of running past it.

    Args:
        image_size: spatial dimensions of the image to iterate over.
        patch_size: spatial size of each patch.
        scan_interval: step between consecutive window starts per axis
            (an entry of 0 yields a single window on that axis).

    Returns:
        A list of index tuples, one per patch, ordered row-major over the window
        starts. Each tuple is ``(slice(None), slice(None), *spatial_slices)``,
        so it can index an array whose two leading axes are taken in full.
    """
    num_spatial_dims = len(image_size)
    scan_num = []
    for i in range(num_spatial_dims):
        if scan_interval[i] == 0:
            scan_num.append(1)
            continue
        num = int(math.ceil(float(image_size[i]) / scan_interval[i]))
        # Index of the first window that reaches the far edge of the image.
        scan_dim = next(
            (d for d in range(num) if d * scan_interval[i] + patch_size[i] >= image_size[i]),
            None,
        )
        scan_num.append(scan_dim + 1 if scan_dim is not None else 1)

    starts = []
    for dim in range(num_spatial_dims):
        dim_starts = []
        for idx in range(scan_num[dim]):
            start_idx = idx * scan_interval[dim]
            # Pull an overhanging window back so it ends at the image boundary.
            start_idx -= max(start_idx + patch_size[dim] - image_size[dim], 0)
            dim_starts.append(start_idx)
        starts.append(dim_starts)

    # Cartesian product of per-axis starts, in "ij" (row-major) order.
    out = np.asarray([x.flatten() for x in np.meshgrid(*starts, indexing="ij")]).T
    return [(slice(None),) * 2 + tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out]
-
def create_sliding_window(image, roi_size, overlap):
    """
    Cut ``image`` into overlapping windows of size ``roi_size``.

    The two leading axes of ``image`` are kept whole (presumably batch and
    channel; confirm against callers). The remaining axes are tiled with
    windows whose step is derived from ``overlap``.

    Args:
        image: array indexed as ``image[:, :, *spatial]``.
        roi_size: window size for each spatial axis.
        overlap: fraction of a window shared with its neighbour, in ``[0, 1)``.

    Returns:
        (windows, slices): the list of window arrays and, in the same order,
        the index tuples used to extract them from ``image``.

    Raises:
        AssertionError: if ``overlap`` is outside ``[0, 1)``.
        ValueError: if ``roi_size`` does not have one entry per spatial axis.
    """
    num_image_dims = len(image.shape) - 2
    if overlap < 0 or overlap >= 1:
        raise AssertionError("overlap must be >= 0 and < 1.")
    if len(roi_size) != num_image_dims:
        # Checked here so a short roi_size fails with ValueError, not IndexError below.
        raise ValueError("roi size different from spatial dims.")
    spatial_shape = image.shape[2:]
    # The window grid is laid out over max(image, roi) per axis.
    # NOTE(review): the image itself is not padded, so on an axis shorter than
    # the ROI the extracted window is truncated to the image extent; confirm
    # that callers pad beforehand if full-size windows are required.
    image_size = tuple(max(spatial_shape[i], roi_size[i]) for i in range(num_image_dims))

    scan_interval = _get_scan_interval(image_size, roi_size, num_image_dims, overlap)
    slices = dense_patch_slices(image_size, roi_size, scan_interval)
    windows_sliding = [image[window_slice] for window_slice in slices]
    return windows_sliding, slices
-
def one_hot(labels, num_classes=None):
    """
    Convert integer label maps into float32 one-hot encodings.

    Args:
        labels: array of shape ``(N, C_in, *spatial)``. Values are cast to
            ``int32`` and used as class indices; negative values index from
            the end, as with NumPy indexing. ``C_in`` is expected to be 1,
            otherwise the final reshape fails.
        num_classes: number of output classes. Defaults to
            ``config['num_classes']`` for backward compatibility.

    Returns:
        float32 array of shape ``(N, num_classes, *spatial)`` with a single 1
        along axis 1 at every spatial position.

    Raises:
        IndexError: if a label is ``>= num_classes``.
    """
    if num_classes is None:
        num_classes = config['num_classes']
    batch = labels.shape[0]
    spatial_shape = tuple(labels.shape[2:])
    flat = np.reshape(labels, (batch, -1)).astype(np.int32)
    encoding = np.zeros((batch, num_classes, flat.shape[1]), dtype=np.float32)
    # Vectorized equivalent of encoding[i, flat[i, j], j] = 1 for every (i, j);
    # avoids a Python loop over every voxel.
    np.put_along_axis(encoding, flat[:, np.newaxis, :], 1, axis=1)
    return np.reshape(encoding, (batch, num_classes) + spatial_shape)
-
def CalculateDice(y_pred, label):
    """
    Compute the Dice coefficient between a prediction and its ground truth.

    ``y_pred`` is turned into hard labels with an argmax over axis 1. Both
    tensors are then one-hot encoded with ``one_hot`` and the background
    channel (index 0) is dropped via ``ignore_background``. A single Dice
    score is computed over all remaining channels and all batch samples
    together.

    Args:
        y_pred: per-class scores with the class axis at position 1,
            e.g. ``(B, num_classes, D, H, W)``.
        label: ground-truth label map with a singleton channel axis,
            e.g. ``(B, 1, D, H, W)``; the first dim is batch.

    Returns:
        (dice, y_pred_output): the scalar Dice coefficient (float64) and the
        argmax prediction with the class axis kept as size 1.
    """
    y_pred_output = np.expand_dims(np.argmax(y_pred, axis=1), axis=1)
    y_pred = one_hot(y_pred_output)
    y = one_hot(label)
    y_pred, y = ignore_background(y_pred, y)
    # Accumulate in float64. The previous float32 dot products stop counting
    # exactly once a total passes 2**24, and casting the result afterwards
    # cannot recover the lost precision.
    inter = np.sum(y_pred * y, dtype=np.float64)
    union = np.sum(np.square(y_pred), dtype=np.float64) + np.sum(np.square(y), dtype=np.float64)
    single_dice_coeff = 2 * inter / (union + 1e-6)
    return single_dice_coeff, y_pred_output
-
def ignore_background(y_pred, label):
    """
    Drop channel 0 (the background) from `y_pred` and `label`.

    Each array is handled independently and is sliced only when it has more
    than one channel along axis 1; single-channel inputs are returned as is.

    Args:
        y_pred: predictions with the channel/class axis at position 1,
            e.g. [BN] for classification or [BNHW] / [BNHWD] for segmentation.
        label: ground truth laid out the same way; the first dim is batch.

    Returns:
        tuple: ``(y_pred, label)`` with the background channel removed where applicable.
    """
    def _strip(arr):
        # With a single channel, removing channel 0 would leave nothing, so keep it.
        return arr if arr.shape[1] <= 1 else arr[:, 1:]

    return _strip(y_pred), _strip(label)
|