diff --git a/model_zoo/official/cv/unet/eval.py b/model_zoo/official/cv/unet/eval.py index f3a2cce6f4..3755386344 100644 --- a/model_zoo/official/cv/unet/eval.py +++ b/model_zoo/official/cv/unet/eval.py @@ -24,7 +24,7 @@ from mindspore import context, Model from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.nn.loss.loss import _Loss -from src.data_loader import create_dataset +from src.data_loader import create_dataset, create_cell_nuclei_dataset from src.unet_medical import UNetMedical from src.unet_nested import NestedUNet, UNet from src.config import cfg_unet @@ -59,6 +59,7 @@ class dice_coeff(nn.Metric): self.clear() def clear(self): self._dice_coeff_sum = 0 + self._iou_sum = 0 self._samples_num = 0 def update(self, *inputs): @@ -77,13 +78,15 @@ class dice_coeff(nn.Metric): union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten()) single_dice_coeff = 2*float(inter)/float(union+1e-6) - print("single dice coeff is:", single_dice_coeff) + single_iou = single_dice_coeff / (2 - single_dice_coeff) + print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou)) self._dice_coeff_sum += single_dice_coeff + self._iou_sum += single_iou def eval(self): if self._samples_num == 0: raise RuntimeError('Total samples num must not be 0.') - return self._dice_coeff_sum / float(self._samples_num) + return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num)) def test_net(data_dir, @@ -93,7 +96,8 @@ def test_net(data_dir, if cfg['model'] == 'unet_medical': net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) elif cfg['model'] == 'unet_nested': - net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) + net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'], + use_bn=cfg['use_bn'], use_ds=False) elif cfg['model'] == 'unet_simple': net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) else: @@ -102,13 +106,17 @@ def test_net(data_dir, load_param_into_net(net, param_dict) criterion = CrossEntropyWithLogits() - _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, - do_crop=cfg['crop'], img_size=cfg['img_size']) + if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei": + valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False, split=0.8) + else: + _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, + do_crop=cfg['crop'], img_size=cfg['img_size']) model = Model(net, loss_fn=criterion, metrics={"dice_coeff": dice_coeff()}) print("============== Starting Evaluating ============") - dice_score = model.eval(valid_dataset, dataset_sink_mode=False) - print("============== Cross valid dice coeff is:", dice_score) + eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"] + print("============== Cross valid dice coeff is:", eval_score[0]) + print("============== Cross valid IOU is:", eval_score[1]) def get_args(): diff --git a/model_zoo/official/cv/unet/export.py b/model_zoo/official/cv/unet/export.py index 17eb773f40..b26d8e3d12 100644 --- a/model_zoo/official/cv/unet/export.py +++ b/model_zoo/official/cv/unet/export.py @@ -42,7 +42,8 @@ if __name__ == "__main__": if cfg['model'] == 'unet_medical': net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) elif cfg['model'] == 'unet_nested': - net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) + net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'], + use_bn=cfg['use_bn'], use_ds=False) elif cfg['model'] == 'unet_simple': net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) else: diff --git a/model_zoo/official/cv/unet/src/config.py b/model_zoo/official/cv/unet/src/config.py index 2e34a7bd6e..af5aa359fa 100644 --- a/model_zoo/official/cv/unet/src/config.py +++ b/model_zoo/official/cv/unet/src/config.py @@ -50,6 +50,34 @@ cfg_unet_nested = { 'weight_decay': 0.0005, 'loss_scale': 1024.0, 'FixedLossScaleManager': 1024.0, + 'use_bn': True, + 'use_ds': True, + 'use_deconv': True, + + 'resume': False, + 'resume_ckpt': './', +} + +cfg_unet_nested_cell = { + 'model': 'unet_nested', + 'dataset': 'Cell_nuclei', + 'crop': None, + 'img_size': [96, 96], + 'lr': 3e-4, + 'epochs': 200, + 'distribute_epochs': 1600, + 'batchsize': 16, + 'cross_valid_ind': 1, + 'num_classes': 2, + 'num_channels': 3, + + 'keep_checkpoint_max': 10, + 'weight_decay': 0.0005, + 'loss_scale': 1024.0, + 'FixedLossScaleManager': 1024.0, + 'use_bn': True, + 'use_ds': True, + 'use_deconv': True, 'resume': False, 'resume_ckpt': './', diff --git a/model_zoo/official/cv/unet/src/data_loader.py b/model_zoo/official/cv/unet/src/data_loader.py index 32f056628e..a49e92106b 100644 --- a/model_zoo/official/cv/unet/src/data_loader.py +++ b/model_zoo/official/cv/unet/src/data_loader.py @@ -15,6 +15,7 @@ import os from collections import deque +import cv2 import numpy as np from PIL import Image, ImageSequence import mindspore.dataset as ds @@ -23,7 +24,6 @@ from mindspore.dataset.vision.utils import Inter from mindspore.communication.management import get_rank, get_group_size - def _load_multipage_tiff(path): """Load tiff images containing many images in the channel dimension""" return np.array([np.array(p) for p in ImageSequence.Iterator(Image.open(path))]) @@ -164,3 +164,100 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro valid_ds = valid_ds.batch(batch_size=1, drop_remainder=True) return train_ds, valid_ds + +class CellNucleiDataset: + """ + Cell nuclei dataset preprocess class. + """ + def __init__(self, data_dir, repeat, is_train=False, split=0.8): + self.data_dir = data_dir + self.img_ids = sorted(next(os.walk(self.data_dir))[1]) + self.train_ids = self.img_ids[:int(len(self.img_ids) * split)] * repeat + np.random.shuffle(self.train_ids) + self.val_ids = self.img_ids[int(len(self.img_ids) * split):] + self.is_train = is_train + self._preprocess_dataset() + + def _preprocess_dataset(self): + for img_id in self.img_ids: + path = os.path.join(self.data_dir, img_id) + if (not os.path.exists(os.path.join(path, "image.png"))) or \ + (not os.path.exists(os.path.join(path, "mask.png"))): + img = cv2.imread(os.path.join(path, "images", img_id + ".png")) + if len(img.shape) == 2: + img = np.expand_dims(img, axis=-1) + img = np.concatenate([img, img, img], axis=-1) + mask = [] + for mask_file in next(os.walk(os.path.join(path, "masks")))[2]: + mask_ = cv2.imread(os.path.join(path, "masks", mask_file), cv2.IMREAD_GRAYSCALE) + mask.append(mask_) + mask = np.max(mask, axis=0) + cv2.imwrite(os.path.join(path, "image.png"), img) + cv2.imwrite(os.path.join(path, "mask.png"), mask) + + def _read_img_mask(self, img_id): + path = os.path.join(self.data_dir, img_id) + img = cv2.imread(os.path.join(path, "image.png")) + mask = cv2.imread(os.path.join(path, "mask.png"), cv2.IMREAD_GRAYSCALE) + return img, mask + + def __getitem__(self, index): + if self.is_train: + return self._read_img_mask(self.train_ids[index]) + return self._read_img_mask(self.val_ids[index]) + + @property + def column_names(self): + column_names = ['image', 'mask'] + return column_names + + def __len__(self): + if self.is_train: + return len(self.train_ids) + return len(self.val_ids) + +def preprocess_img_mask(img, mask, img_size, augment=False): + """ + Preprocess for cell nuclei dataset. + Random crop and flip images and masks when augment is True. + """ + if augment: + img_size_w = int(np.random.randint(img_size[0], img_size[0] * 1.5, 1)) + img_size_h = int(np.random.randint(img_size[1], img_size[1] * 1.5, 1)) + img = cv2.resize(img, (img_size_w, img_size_h)) + mask = cv2.resize(mask, (img_size_w, img_size_h)) + dw = int(np.random.randint(0, img_size_w - img_size[0] + 1, 1)) + dh = int(np.random.randint(0, img_size_h - img_size[1] + 1, 1)) + img = img[dh:dh+img_size[1], dw:dw+img_size[0], :] + mask = mask[dh:dh+img_size[1], dw:dw+img_size[0]] + if np.random.random() > 0.5: + flip_code = int(np.random.randint(-1, 2, 1)) + img = cv2.flip(img, flip_code) + mask = cv2.flip(mask, flip_code) + else: + img = cv2.resize(img, img_size) + mask = cv2.resize(mask, img_size) + img = (img.astype(np.float32) - 127.5) / 127.5 + img = img.transpose(2, 0, 1) + mask = mask.astype(np.float32) / 255 + mask = (mask > 0.5).astype(np.int) + mask = (np.arange(2) == mask[..., None]).astype(int) + mask = mask.transpose(2, 0, 1).astype(np.float32) + return img, mask + +def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train=False, augment=False, + split=0.8, rank=0, group_size=1, python_multiprocessing=True, num_parallel_workers=8): + """ + Get generator dataset for cell nuclei dataset. + """ + cell_dataset = CellNucleiDataset(data_dir, repeat, is_train, split) + sampler = ds.DistributedSampler(group_size, rank, shuffle=is_train) + dataset = ds.GeneratorDataset(cell_dataset, cell_dataset.column_names, sampler=sampler) + compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, tuple(img_size), augment and is_train)) + dataset = dataset.map(operations=compose_map_func, input_columns=cell_dataset.column_names, + output_columns=cell_dataset.column_names, column_order=cell_dataset.column_names, + python_multiprocessing=python_multiprocessing, + num_parallel_workers=num_parallel_workers) + dataset = dataset.batch(batch_size, drop_remainder=is_train) + dataset = dataset.repeat(1) + return dataset diff --git a/model_zoo/official/cv/unet/src/loss.py b/model_zoo/official/cv/unet/src/loss.py index 26793bd873..9d84088ce2 100644 --- a/model_zoo/official/cv/unet/src/loss.py +++ b/model_zoo/official/cv/unet/src/loss.py @@ -36,3 +36,15 @@ class CrossEntropyWithLogits(_Loss): loss = self.reduce_mean( self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, 2)), self.reshape_fn(label, (-1, 2)))) return self.get_loss(loss) + +class MultiCrossEntropyWithLogits(nn.Cell): + def __init__(self): + super(MultiCrossEntropyWithLogits, self).__init__() + self.loss = CrossEntropyWithLogits() + self.squeeze = F.Squeeze() + + def construct(self, logits, label): + total_loss = 0 + for i in range(len(logits)): + total_loss += self.loss(self.squeeze(logits[i:i+1]), label) + return total_loss diff --git a/model_zoo/official/cv/unet/src/unet_nested/unet_model.py b/model_zoo/official/cv/unet/src/unet_nested/unet_model.py index bbfba469db..3522373e81 100644 --- a/model_zoo/official/cv/unet/src/unet_nested/unet_model.py +++ b/model_zoo/official/cv/unet/src/unet_nested/unet_model.py @@ -16,6 +16,7 @@ # Model of UnetPlusPlus import mindspore.nn as nn +import mindspore.ops as P from .unet_parts import UnetConv2d, UnetUp @@ -63,6 +64,7 @@ class NestedUNet(nn.Cell): self.final2 = nn.Conv2d(filters[0], n_class, 1) self.final3 = nn.Conv2d(filters[0], n_class, 1) self.final4 = nn.Conv2d(filters[0], n_class, 1) + self.stack = P.Stack(axis=0) def construct(self, inputs): x00 = self.conv00(inputs) # channel = filters[0] @@ -86,13 +88,12 @@ class NestedUNet(nn.Cell): x04 = self.up_concat04(x13, x00, x01, x02, x03) # channel = filters[0] final1 = self.final1(x01) - final2 = self.final1(x02) - final3 = self.final1(x03) - final4 = self.final1(x04) - - final = (final1 + final2 + final3 + final4) / 4.0 + final2 = self.final2(x02) + final3 = self.final3(x03) + final4 = self.final4(x04) if self.use_ds: + final = self.stack((final1, final2, final3, final4)) return final return final4 diff --git a/model_zoo/official/cv/unet/src/utils.py b/model_zoo/official/cv/unet/src/utils.py index d6bb9d653b..cda763efd8 100644 --- a/model_zoo/official/cv/unet/src/utils.py +++ b/model_zoo/official/cv/unet/src/utils.py @@ -51,9 +51,23 @@ class StepLossTimeMonitor(Callback): if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)): raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format( cb_params.cur_epoch_num, cur_step_in_epoch)) + self.losses.append(loss) if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: # TEST print("step: %s, loss is %s, fps is %s" % (cur_step_in_epoch, loss, step_fps), flush=True) + def epoch_begin(self, run_context): + self.epoch_start = time.time() + self.losses = [] + + def epoch_end(self, run_context): + cb_params = run_context.original_args() + epoch_cost = time.time() - self.epoch_start + step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 + step_fps = self.batch_size * 1.0 * step_in_epoch / epoch_cost + print("epoch: {:3d}, avg loss:{:.4f}, total cost: {:.3f} s, per step fps:{:5.3f}".format( + cb_params.cur_epoch_num, np.mean(self.losses), epoch_cost, step_fps), flush=True) + + def mask_to_image(mask): return Image.fromarray((mask * 255).astype(np.uint8)) diff --git a/model_zoo/official/cv/unet/train.py b/model_zoo/official/cv/unet/train.py index 37615f7003..7b1032c367 100644 --- a/model_zoo/official/cv/unet/train.py +++ b/model_zoo/official/cv/unet/train.py @@ -21,15 +21,15 @@ import ast import mindspore import mindspore.nn as nn from mindspore import Model, context -from mindspore.communication.management import init, get_group_size +from mindspore.communication.management import init, get_group_size, get_rank from mindspore.train.callback import CheckpointConfig, ModelCheckpoint from mindspore.context import ParallelMode from mindspore.train.serialization import load_checkpoint, load_param_into_net from src.unet_medical import UNetMedical from src.unet_nested import NestedUNet, UNet -from src.data_loader import create_dataset -from src.loss import CrossEntropyWithLogits +from src.data_loader import create_dataset, create_cell_nuclei_dataset +from src.loss import CrossEntropyWithLogits, MultiCrossEntropyWithLogits from src.utils import StepLossTimeMonitor from src.config import cfg_unet @@ -46,10 +46,12 @@ def train_net(data_dir, run_distribute=False, cfg=None): - + rank = 0 + group_size = 1 if run_distribute: init() group_size = get_group_size() + rank = get_rank() parallel_mode = ParallelMode.DATA_PARALLEL context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, @@ -58,7 +60,8 @@ def train_net(data_dir, if cfg['model'] == 'unet_medical': net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) elif cfg['model'] == 'unet_nested': - net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) + net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'], + use_bn=cfg['use_bn'], use_ds=cfg['use_ds']) elif cfg['model'] == 'unet_simple': net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) else: @@ -68,14 +71,28 @@ def train_net(data_dir, param_dict = load_checkpoint(cfg['resume_ckpt']) load_param_into_net(net, param_dict) - criterion = CrossEntropyWithLogits() - train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute, cfg["crop"], - cfg['img_size']) + if 'use_ds' in cfg and cfg['use_ds']: + criterion = MultiCrossEntropyWithLogits() + else: + criterion = CrossEntropyWithLogits() + if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei": + repeat = 10 + dataset_sink_mode = True + per_print_times = 0 + train_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], repeat, batch_size, + is_train=True, augment=True, split=0.8, rank=rank, + group_size=group_size) + else: + repeat = epochs + dataset_sink_mode = False + per_print_times = 1 + train_dataset, _ = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind, run_distribute, + cfg["crop"], cfg['img_size']) train_data_size = train_dataset.get_dataset_size() print("dataset length is:", train_data_size) ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size, keep_checkpoint_max=cfg['keep_checkpoint_max']) - ckpoint_cb = ModelCheckpoint(prefix='ckpt_unet_medical_adam', + ckpoint_cb = ModelCheckpoint(prefix='ckpt_{}_adam'.format(cfg['model']), directory='./ckpt_{}/'.format(device_id), config=ckpt_config) @@ -87,13 +104,11 @@ def train_net(data_dir, model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3") print("============== Starting Training ==============") - model.train(1, train_dataset, callbacks=[StepLossTimeMonitor(batch_size=batch_size), ckpoint_cb], - dataset_sink_mode=False) + callbacks = [StepLossTimeMonitor(batch_size=batch_size, per_print_times=per_print_times), ckpoint_cb] + model.train(int(epochs / repeat), train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode) print("============== End Training ==============") - - def get_args(): parser = argparse.ArgumentParser(description='Train the UNet on images and target masks', formatter_class=argparse.ArgumentDefaultsHelpFormatter)