| @@ -16,41 +16,30 @@ | |||
| import os | |||
| import argparse | |||
| import logging | |||
| import cv2 | |||
| import numpy as np | |||
| import mindspore | |||
| import mindspore.nn as nn | |||
| import mindspore.ops.operations as F | |||
| from mindspore import context, Model | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.nn.loss.loss import _Loss | |||
| from src.data_loader import create_dataset, create_cell_nuclei_dataset | |||
| from src.unet_medical import UNetMedical | |||
| from src.unet_nested import NestedUNet, UNet | |||
| from src.config import cfg_unet | |||
| from scipy.special import softmax | |||
| from src.utils import UnetEval | |||
| device_id = int(os.getenv('DEVICE_ID')) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) | |||
| class CrossEntropyWithLogits(_Loss): | |||
| class TempLoss(nn.Cell): | |||
| """A temp loss cell.""" | |||
| def __init__(self): | |||
| super(CrossEntropyWithLogits, self).__init__() | |||
| self.transpose_fn = F.Transpose() | |||
| self.reshape_fn = F.Reshape() | |||
| self.softmax_cross_entropy_loss = nn.SoftmaxCrossEntropyWithLogits() | |||
| self.cast = F.Cast() | |||
| super(TempLoss, self).__init__() | |||
| self.identity = F.identity() | |||
| def construct(self, logits, label): | |||
| # NCHW->NHWC | |||
| logits = self.transpose_fn(logits, (0, 2, 3, 1)) | |||
| logits = self.cast(logits, mindspore.float32) | |||
| label = self.transpose_fn(label, (0, 2, 3, 1)) | |||
| loss = self.reduce_mean(self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, 2)), | |||
| self.reshape_fn(label, (-1, 2)))) | |||
| return self.get_loss(loss) | |||
| return self.identity(logits) | |||
| class dice_coeff(nn.Metric): | |||
| @@ -64,16 +53,35 @@ class dice_coeff(nn.Metric): | |||
| def update(self, *inputs): | |||
| if len(inputs) != 2: | |||
| raise ValueError('Mean dice coefficient need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) | |||
| y_pred = self._convert_data(inputs[0]) | |||
| raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs))) | |||
| y = self._convert_data(inputs[1]) | |||
| self._samples_num += y.shape[0] | |||
| y_pred = y_pred.transpose(0, 2, 3, 1) | |||
| y = y.transpose(0, 2, 3, 1) | |||
| y_pred = softmax(y_pred, axis=3) | |||
| b, h, w, c = y.shape | |||
| if b != 1: | |||
| raise ValueError('Batch size should be 1 when in evaluation.') | |||
| y = y.reshape((h, w, c)) | |||
| if cfg_unet["eval_activate"].lower() == "softmax": | |||
| y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0) | |||
| if cfg_unet["eval_resize"]: | |||
| y_pred = [] | |||
| for i in range(cfg_unet["num_classes"]): | |||
| y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, i] * 255), (w, h)) / 255) | |||
| y_pred = np.stack(y_pred, axis=-1) | |||
| else: | |||
| y_pred = y_softmax | |||
| elif cfg_unet["eval_activate"].lower() == "argmax": | |||
| y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0) | |||
| y_pred = [] | |||
| for i in range(cfg_unet["num_classes"]): | |||
| if cfg_unet["eval_resize"]: | |||
| y_pred.append(cv2.resize(np.uint8(y_argmax == i), (w, h), interpolation=cv2.INTER_NEAREST)) | |||
| else: | |||
| y_pred.append(np.float32(y_argmax == i)) | |||
| y_pred = np.stack(y_pred, axis=-1) | |||
| else: | |||
| raise ValueError('config eval_activate should be softmax or argmax.') | |||
| y_pred = y_pred.astype(np.float32) | |||
| inter = np.dot(y_pred.flatten(), y.flatten()) | |||
| union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten()) | |||
| @@ -104,14 +112,14 @@ def test_net(data_dir, | |||
| raise ValueError("Unsupported model: {}".format(cfg['model'])) | |||
| param_dict = load_checkpoint(ckpt_path) | |||
| load_param_into_net(net, param_dict) | |||
| criterion = CrossEntropyWithLogits() | |||
| net = UnetEval(net) | |||
| if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei": | |||
| valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False, split=0.8) | |||
| valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False, | |||
| eval_resize=cfg["eval_resize"], split=0.8) | |||
| else: | |||
| _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, | |||
| do_crop=cfg['crop'], img_size=cfg['img_size']) | |||
| model = Model(net, loss_fn=criterion, metrics={"dice_coeff": dice_coeff()}) | |||
| model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff()}) | |||
| print("============== Starting Evaluating ============") | |||
| eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"] | |||
| @@ -33,7 +33,9 @@ cfg_unet_medical = { | |||
| 'resume': False, | |||
| 'resume_ckpt': './', | |||
| 'transfer_training': False, | |||
| 'filter_weight': ['outc.weight', 'outc.bias'] | |||
| 'filter_weight': ['outc.weight', 'outc.bias'], | |||
| 'eval_activate': 'Softmax', | |||
| 'eval_resize': False | |||
| } | |||
| cfg_unet_nested = { | |||
| @@ -59,7 +61,9 @@ cfg_unet_nested = { | |||
| 'resume': False, | |||
| 'resume_ckpt': './', | |||
| 'transfer_training': False, | |||
| 'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'] | |||
| 'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'], | |||
| 'eval_activate': 'Softmax', | |||
| 'eval_resize': False | |||
| } | |||
| cfg_unet_nested_cell = { | |||
| @@ -86,7 +90,9 @@ cfg_unet_nested_cell = { | |||
| 'resume': False, | |||
| 'resume_ckpt': './', | |||
| 'transfer_training': False, | |||
| 'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'] | |||
| 'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'], | |||
| 'eval_activate': 'Softmax', | |||
| 'eval_resize': False | |||
| } | |||
| cfg_unet_simple = { | |||
| @@ -109,7 +115,12 @@ cfg_unet_simple = { | |||
| 'resume': False, | |||
| 'resume_ckpt': './', | |||
| 'transfer_training': False, | |||
| 'filter_weight': ["final.weight"] | |||
| 'filter_weight': ["final.weight"], | |||
| 'eval_activate': 'Softmax', | |||
| 'eval_resize': False | |||
| } | |||
| cfg_unet = cfg_unet_medical | |||
| if not ('dataset' in cfg_unet and cfg_unet['dataset'] == 'Cell_nuclei') and cfg_unet['eval_resize']: | |||
| print("ISBI dataset not support resize to original image size when in evaluation.") | |||
| cfg_unet['eval_resize'] = False | |||
| @@ -216,7 +216,7 @@ class CellNucleiDataset: | |||
| return len(self.train_ids) | |||
| return len(self.val_ids) | |||
| def preprocess_img_mask(img, mask, img_size, augment=False): | |||
| def preprocess_img_mask(img, mask, img_size, augment=False, eval_resize=False): | |||
| """ | |||
| Preprocess for cell nuclei dataset. | |||
| Random crop and flip images and masks when augment is True. | |||
| @@ -236,7 +236,8 @@ def preprocess_img_mask(img, mask, img_size, augment=False): | |||
| mask = cv2.flip(mask, flip_code) | |||
| else: | |||
| img = cv2.resize(img, img_size) | |||
| mask = cv2.resize(mask, img_size) | |||
| if not eval_resize: | |||
| mask = cv2.resize(mask, img_size) | |||
| img = (img.astype(np.float32) - 127.5) / 127.5 | |||
| img = img.transpose(2, 0, 1) | |||
| mask = mask.astype(np.float32) / 255 | |||
| @@ -245,7 +246,7 @@ def preprocess_img_mask(img, mask, img_size, augment=False): | |||
| mask = mask.transpose(2, 0, 1).astype(np.float32) | |||
| return img, mask | |||
| def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train=False, augment=False, | |||
| def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train=False, augment=False, eval_resize=False, | |||
| split=0.8, rank=0, group_size=1, python_multiprocessing=True, num_parallel_workers=8): | |||
| """ | |||
| Get generator dataset for cell nuclei dataset. | |||
| @@ -253,7 +254,8 @@ def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train= | |||
| cell_dataset = CellNucleiDataset(data_dir, repeat, is_train, split) | |||
| sampler = ds.DistributedSampler(group_size, rank, shuffle=is_train) | |||
| dataset = ds.GeneratorDataset(cell_dataset, cell_dataset.column_names, sampler=sampler) | |||
| compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, tuple(img_size), augment and is_train)) | |||
| compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, tuple(img_size), augment and is_train, | |||
| eval_resize)) | |||
| dataset = dataset.map(operations=compose_map_func, input_columns=cell_dataset.column_names, | |||
| output_columns=cell_dataset.column_names, column_order=cell_dataset.column_names, | |||
| python_multiprocessing=python_multiprocessing, | |||
| @@ -16,9 +16,29 @@ | |||
| import time | |||
| import numpy as np | |||
| from PIL import Image | |||
| from mindspore import nn | |||
| from mindspore.ops import operations as ops | |||
| from mindspore.train.callback import Callback | |||
| from mindspore.common.tensor import Tensor | |||
| class UnetEval(nn.Cell): | |||
| """ | |||
| Add Unet evaluation activation. | |||
| """ | |||
| def __init__(self, net): | |||
| super(UnetEval, self).__init__() | |||
| self.net = net | |||
| self.transpose = ops.Transpose() | |||
| self.softmax = ops.Softmax(axis=-1) | |||
| self.argmax = ops.Argmax(axis=-1) | |||
| def construct(self, x): | |||
| out = self.net(x) | |||
| out = self.transpose(out, (0, 2, 3, 1)) | |||
| softmax_out = self.softmax(out) | |||
| argmax_out = self.argmax(out) | |||
| return (softmax_out, argmax_out) | |||
| class StepLossTimeMonitor(Callback): | |||
| def __init__(self, batch_size, per_print_times=1): | |||