
add cell nuclei dataset for unet++

tags/v1.2.0-rc1
zhaoting 4 years ago
parent commit 8710c953ca
8 changed files with 204 additions and 28 deletions
  1. model_zoo/official/cv/unet/eval.py                        +16  -8
  2. model_zoo/official/cv/unet/export.py                       +2  -1
  3. model_zoo/official/cv/unet/src/config.py                  +28  -0
  4. model_zoo/official/cv/unet/src/data_loader.py             +98  -1
  5. model_zoo/official/cv/unet/src/loss.py                    +12  -0
  6. model_zoo/official/cv/unet/src/unet_nested/unet_model.py   +6  -5
  7. model_zoo/official/cv/unet/src/utils.py                   +14  -0
  8. model_zoo/official/cv/unet/train.py                       +28 -13

+16 -8   model_zoo/official/cv/unet/eval.py

@@ -24,7 +24,7 @@ from mindspore import context, Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.nn.loss.loss import _Loss
 
-from src.data_loader import create_dataset
+from src.data_loader import create_dataset, create_cell_nuclei_dataset
 from src.unet_medical import UNetMedical
 from src.unet_nested import NestedUNet, UNet
 from src.config import cfg_unet
@@ -59,6 +59,7 @@ class dice_coeff(nn.Metric):
         self.clear()
 
     def clear(self):
         self._dice_coeff_sum = 0
+        self._iou_sum = 0
         self._samples_num = 0
 
     def update(self, *inputs):
@@ -77,13 +78,15 @@ class dice_coeff(nn.Metric):
         union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
 
         single_dice_coeff = 2*float(inter)/float(union+1e-6)
-        print("single dice coeff is:", single_dice_coeff)
+        single_iou = single_dice_coeff / (2 - single_dice_coeff)
+        print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou))
         self._dice_coeff_sum += single_dice_coeff
+        self._iou_sum += single_iou
 
     def eval(self):
         if self._samples_num == 0:
             raise RuntimeError('Total samples num must not be 0.')
-        return self._dice_coeff_sum / float(self._samples_num)
+        return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num))
@@ -93,7 +96,8 @@ def test_net(data_dir,
     if cfg['model'] == 'unet_medical':
         net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
     elif cfg['model'] == 'unet_nested':
-        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
+        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
+                         use_bn=cfg['use_bn'], use_ds=False)
     elif cfg['model'] == 'unet_simple':
         net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
     else:
@@ -102,13 +106,17 @@ def test_net(data_dir,
     load_param_into_net(net, param_dict)
 
     criterion = CrossEntropyWithLogits()
-    _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
-                                      do_crop=cfg['crop'], img_size=cfg['img_size'])
+    if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
+        valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False, split=0.8)
+    else:
+        _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
+                                          do_crop=cfg['crop'], img_size=cfg['img_size'])
     model = Model(net, loss_fn=criterion, metrics={"dice_coeff": dice_coeff()})
 
     print("============== Starting Evaluating ============")
-    dice_score = model.eval(valid_dataset, dataset_sink_mode=False)
-    print("============== Cross valid dice coeff is:", dice_score)
+    eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"]
+    print("============== Cross valid dice coeff is:", eval_score[0])
+    print("============== Cross valid IOU is:", eval_score[1])
 
 
 def get_args():
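
Note: the IoU added here is not measured separately but derived algebraically from the Dice coefficient. With Dice = 2·|A∩B| / (|A|+|B|) and IoU = |A∩B| / (|A|+|B| − |A∩B|), it follows that IoU = Dice / (2 − Dice), which is exactly the `single_iou` expression above. A quick standalone check (plain numpy, illustrative values only):

import numpy as np

# Dice = 2|A∩B| / (|A|+|B|);  IoU = |A∩B| / (|A|+|B| - |A∩B|)
# => IoU = Dice / (2 - Dice), the identity used in dice_coeff.update().
y_pred = np.array([1., 1., 0., 1., 0.])
y_true = np.array([1., 0., 0., 1., 1.])
inter = np.dot(y_pred, y_true)
union = np.dot(y_pred, y_pred) + np.dot(y_true, y_true)
dice = 2 * inter / (union + 1e-6)
iou_direct = inter / (union - inter + 1e-6)
assert np.isclose(dice / (2 - dice), iou_direct, atol=1e-4)   # both ~0.5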


+2 -1   model_zoo/official/cv/unet/export.py

@@ -42,7 +42,8 @@ if __name__ == "__main__":
     if cfg['model'] == 'unet_medical':
         net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
     elif cfg['model'] == 'unet_nested':
-        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
+        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
+                         use_bn=cfg['use_bn'], use_ds=False)
     elif cfg['model'] == 'unet_simple':
         net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
     else:
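
Both eval.py and export.py build NestedUNet with `use_ds=False`, so the exported graph emits the single `final4` tensor rather than the stacked deep-supervision output. A minimal export sketch under that setting (assumes the MindSpore 1.x serialization API; the checkpoint path is a placeholder, and the input shape follows cfg_unet_nested_cell):

import numpy as np
from mindspore import Tensor, context
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
from src.unet_nested import NestedUNet

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
net = NestedUNet(in_channel=3, n_class=2, use_deconv=True, use_bn=True, use_ds=False)
load_param_into_net(net, load_checkpoint("./ckpt_unet_nested_adam.ckpt"))  # placeholder path

# With use_ds=False the network has one output of shape (N, 2, 96, 96);
# with use_ds=True it would instead return the stacked (4, N, 2, 96, 96) tensor.
dummy = Tensor(np.zeros([1, 3, 96, 96], dtype=np.float32))
export(net, dummy, file_name="unet_nested_cell", file_format="AIR")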


+28 -0   model_zoo/official/cv/unet/src/config.py

@@ -50,6 +50,34 @@ cfg_unet_nested = {
     'weight_decay': 0.0005,
     'loss_scale': 1024.0,
     'FixedLossScaleManager': 1024.0,
+    'use_bn': True,
+    'use_ds': True,
+    'use_deconv': True,
+
+    'resume': False,
+    'resume_ckpt': './',
+}
+
+cfg_unet_nested_cell = {
+    'model': 'unet_nested',
+    'dataset': 'Cell_nuclei',
+    'crop': None,
+    'img_size': [96, 96],
+    'lr': 3e-4,
+    'epochs': 200,
+    'distribute_epochs': 1600,
+    'batchsize': 16,
+    'cross_valid_ind': 1,
+    'num_classes': 2,
+    'num_channels': 3,
+    'keep_checkpoint_max': 10,
+    'weight_decay': 0.0005,
+    'loss_scale': 1024.0,
+    'FixedLossScaleManager': 1024.0,
+    'use_bn': True,
+    'use_ds': True,
+    'use_deconv': True,
+
     'resume': False,
     'resume_ckpt': './',
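
cfg_unet_nested_cell reuses the UNet++ hyper-parameters but switches to 3-channel cell nuclei images at 96×96. The rest of the code reads its config through the `cfg_unet` alias imported from src.config, so switching datasets is presumably a one-line change there; a sketch of that assumed pattern:

# Presumed selection in src/config.py (cfg_unet is what train.py, eval.py and
# export.py import; pointing the alias at another dict switches the run):
cfg_unet = cfg_unet_nested_cell   # UNet++ on the Cell_nuclei dataset
# cfg_unet = cfg_unet_nested      # UNet++ on the original multipage-TIFF data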


+98 -1   model_zoo/official/cv/unet/src/data_loader.py

@@ -15,6 +15,7 @@
 
 import os
 from collections import deque
+import cv2
 import numpy as np
 from PIL import Image, ImageSequence
 import mindspore.dataset as ds
@@ -23,7 +24,6 @@ from mindspore.dataset.vision.utils import Inter
 from mindspore.communication.management import get_rank, get_group_size
 
 
-
 def _load_multipage_tiff(path):
     """Load tiff images containing many images in the channel dimension"""
     return np.array([np.array(p) for p in ImageSequence.Iterator(Image.open(path))])
@@ -164,3 +164,100 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro
     valid_ds = valid_ds.batch(batch_size=1, drop_remainder=True)
 
     return train_ds, valid_ds
+
+class CellNucleiDataset:
+    """
+    Cell nuclei dataset preprocess class.
+    """
+    def __init__(self, data_dir, repeat, is_train=False, split=0.8):
+        self.data_dir = data_dir
+        self.img_ids = sorted(next(os.walk(self.data_dir))[1])
+        self.train_ids = self.img_ids[:int(len(self.img_ids) * split)] * repeat
+        np.random.shuffle(self.train_ids)
+        self.val_ids = self.img_ids[int(len(self.img_ids) * split):]
+        self.is_train = is_train
+        self._preprocess_dataset()
+
+    def _preprocess_dataset(self):
+        for img_id in self.img_ids:
+            path = os.path.join(self.data_dir, img_id)
+            if (not os.path.exists(os.path.join(path, "image.png"))) or \
+                    (not os.path.exists(os.path.join(path, "mask.png"))):
+                img = cv2.imread(os.path.join(path, "images", img_id + ".png"))
+                if len(img.shape) == 2:
+                    img = np.expand_dims(img, axis=-1)
+                    img = np.concatenate([img, img, img], axis=-1)
+                mask = []
+                for mask_file in next(os.walk(os.path.join(path, "masks")))[2]:
+                    mask_ = cv2.imread(os.path.join(path, "masks", mask_file), cv2.IMREAD_GRAYSCALE)
+                    mask.append(mask_)
+                mask = np.max(mask, axis=0)
+                cv2.imwrite(os.path.join(path, "image.png"), img)
+                cv2.imwrite(os.path.join(path, "mask.png"), mask)
+
+    def _read_img_mask(self, img_id):
+        path = os.path.join(self.data_dir, img_id)
+        img = cv2.imread(os.path.join(path, "image.png"))
+        mask = cv2.imread(os.path.join(path, "mask.png"), cv2.IMREAD_GRAYSCALE)
+        return img, mask
+
+    def __getitem__(self, index):
+        if self.is_train:
+            return self._read_img_mask(self.train_ids[index])
+        return self._read_img_mask(self.val_ids[index])
+
+    @property
+    def column_names(self):
+        column_names = ['image', 'mask']
+        return column_names
+
+    def __len__(self):
+        if self.is_train:
+            return len(self.train_ids)
+        return len(self.val_ids)
+
+def preprocess_img_mask(img, mask, img_size, augment=False):
+    """
+    Preprocess for cell nuclei dataset.
+    Random crop and flip images and masks when augment is True.
+    """
+    if augment:
+        img_size_w = int(np.random.randint(img_size[0], img_size[0] * 1.5, 1))
+        img_size_h = int(np.random.randint(img_size[1], img_size[1] * 1.5, 1))
+        img = cv2.resize(img, (img_size_w, img_size_h))
+        mask = cv2.resize(mask, (img_size_w, img_size_h))
+        dw = int(np.random.randint(0, img_size_w - img_size[0] + 1, 1))
+        dh = int(np.random.randint(0, img_size_h - img_size[1] + 1, 1))
+        img = img[dh:dh+img_size[1], dw:dw+img_size[0], :]
+        mask = mask[dh:dh+img_size[1], dw:dw+img_size[0]]
+        if np.random.random() > 0.5:
+            flip_code = int(np.random.randint(-1, 2, 1))
+            img = cv2.flip(img, flip_code)
+            mask = cv2.flip(mask, flip_code)
+    else:
+        img = cv2.resize(img, img_size)
+        mask = cv2.resize(mask, img_size)
+    img = (img.astype(np.float32) - 127.5) / 127.5
+    img = img.transpose(2, 0, 1)
+    mask = mask.astype(np.float32) / 255
+    mask = (mask > 0.5).astype(np.int)
+    mask = (np.arange(2) == mask[..., None]).astype(int)
+    mask = mask.transpose(2, 0, 1).astype(np.float32)
+    return img, mask
+
+def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train=False, augment=False,
+                               split=0.8, rank=0, group_size=1, python_multiprocessing=True,
+                               num_parallel_workers=8):
+    """
+    Get generator dataset for cell nuclei dataset.
+    """
+    cell_dataset = CellNucleiDataset(data_dir, repeat, is_train, split)
+    sampler = ds.DistributedSampler(group_size, rank, shuffle=is_train)
+    dataset = ds.GeneratorDataset(cell_dataset, cell_dataset.column_names, sampler=sampler)
+    compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, tuple(img_size),
+                                                                augment and is_train))
+    dataset = dataset.map(operations=compose_map_func, input_columns=cell_dataset.column_names,
+                          output_columns=cell_dataset.column_names, column_order=cell_dataset.column_names,
+                          python_multiprocessing=python_multiprocessing,
+                          num_parallel_workers=num_parallel_workers)
+    dataset = dataset.batch(batch_size, drop_remainder=is_train)
+    dataset = dataset.repeat(1)
+    return dataset
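
Two details of this loader are easy to miss. First, the one-hot step in preprocess_img_mask works by broadcasting: comparing an integer (H, W, 1) mask against np.arange(2) yields an (H, W, 2) array with one channel per class, which is then transposed to CHW. In isolation (plain numpy, illustrative values):

import numpy as np

mask = np.array([[0, 1], [1, 0]])                        # (H, W) binary mask
one_hot = (np.arange(2) == mask[..., None]).astype(np.float32)
print(one_hot.shape)      # (2, 2, 2): background channel 0, foreground channel 1
print(one_hot[..., 1])    # equals the original mask
one_hot = one_hot.transpose(2, 0, 1)                     # (C, H, W) for the loss

Second, CellNucleiDataset expects what looks like the Kaggle 2018 Data Science Bowl layout (data_dir/<img_id>/images/<img_id>.png plus one binary PNG per nucleus under masks/), merging the per-nucleus masks with np.max and caching the result as image.png/mask.png. A usage sketch under that assumed layout (the directory name is illustrative):

train_ds = create_cell_nuclei_dataset("./stage1_train", img_size=[96, 96],
                                      repeat=10, batch_size=16,
                                      is_train=True, augment=True, split=0.8)
for batch in train_ds.create_dict_iterator(output_numpy=True):
    print(batch["image"].shape, batch["mask"].shape)     # (16, 3, 96, 96) (16, 2, 96, 96)
    break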

+12 -0   model_zoo/official/cv/unet/src/loss.py

@@ -36,3 +36,15 @@ class CrossEntropyWithLogits(_Loss):
         loss = self.reduce_mean(
             self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, 2)), self.reshape_fn(label, (-1, 2))))
         return self.get_loss(loss)
+
+class MultiCrossEntropyWithLogits(nn.Cell):
+    def __init__(self):
+        super(MultiCrossEntropyWithLogits, self).__init__()
+        self.loss = CrossEntropyWithLogits()
+        self.squeeze = F.Squeeze()
+
+    def construct(self, logits, label):
+        total_loss = 0
+        for i in range(len(logits)):
+            total_loss += self.loss(self.squeeze(logits[i:i+1]), label)
+        return total_loss
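
MultiCrossEntropyWithLogits assumes the deep-supervision layout produced by NestedUNet below: the four head logits stacked on axis 0. Each `logits[i:i+1]` slice keeps a leading singleton axis (slicing, unlike integer indexing, stays graph-mode friendly), which Squeeze removes before the per-head loss; the head losses are summed with equal weight. A numpy stand-in for the shape bookkeeping (not MindSpore code; mean() is a placeholder for the real loss):

import numpy as np

logits = np.zeros([4, 16, 2, 96, 96], dtype=np.float32)   # (heads, N, C, H, W)
label = np.zeros([16, 2, 96, 96], dtype=np.float32)

total_loss = 0.0
for i in range(len(logits)):                   # len() of axis 0 -> 4 heads
    head = np.squeeze(logits[i:i+1], axis=0)   # (1, N, C, H, W) -> (N, C, H, W)
    total_loss += head.mean()                  # placeholder for CrossEntropyWithLogits(head, label)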

+6 -5   model_zoo/official/cv/unet/src/unet_nested/unet_model.py

@@ -16,6 +16,7 @@
 # Model of UnetPlusPlus
 
 import mindspore.nn as nn
+import mindspore.ops as P
 from .unet_parts import UnetConv2d, UnetUp
@@ -63,6 +64,7 @@ class NestedUNet(nn.Cell):
         self.final2 = nn.Conv2d(filters[0], n_class, 1)
         self.final3 = nn.Conv2d(filters[0], n_class, 1)
         self.final4 = nn.Conv2d(filters[0], n_class, 1)
+        self.stack = P.Stack(axis=0)
 
     def construct(self, inputs):
         x00 = self.conv00(inputs)    # channel = filters[0]
@@ -86,13 +88,12 @@ class NestedUNet(nn.Cell):
         x04 = self.up_concat04(x13, x00, x01, x02, x03)    # channel = filters[0]
 
         final1 = self.final1(x01)
-        final2 = self.final1(x02)
-        final3 = self.final1(x03)
-        final4 = self.final1(x04)
-
-        final = (final1 + final2 + final3 + final4) / 4.0
+        final2 = self.final2(x02)
+        final3 = self.final3(x03)
+        final4 = self.final4(x04)
 
         if self.use_ds:
+            final = self.stack((final1, final2, final3, final4))
             return final
         return final4
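Besides wiring up deep supervision, this hunk fixes a real bug: final2, final3 and final4 previously all went through self.final1, so three of the four 1×1 heads were never trained. With the fix, use_ds=True stacks the four head outputs on a new leading axis (matching the loop in MultiCrossEntropyWithLogits), while use_ds=False returns only the deepest head. A minimal shape check (assumes a GPU/Ascend backend where the conv/deconv ops are available; inputs are zeros, only shapes matter):

import numpy as np
from mindspore import Tensor, context
from src.unet_nested import NestedUNet

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")  # backend assumption
x = Tensor(np.zeros([1, 3, 96, 96], dtype=np.float32))

net_ds = NestedUNet(in_channel=3, n_class=2, use_deconv=True, use_bn=True, use_ds=True)
print(net_ds(x).shape)      # (4, 1, 2, 96, 96): four supervised heads stacked

net_infer = NestedUNet(in_channel=3, n_class=2, use_deconv=True, use_bn=True, use_ds=False)
print(net_infer(x).shape)   # (1, 2, 96, 96): final4 only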




+14 -0   model_zoo/official/cv/unet/src/utils.py

@@ -51,9 +51,23 @@ class StepLossTimeMonitor(Callback):
         if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
             raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
                 cb_params.cur_epoch_num, cur_step_in_epoch))
+        self.losses.append(loss)
         if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
             # TEST
             print("step: %s, loss is %s, fps is %s" % (cur_step_in_epoch, loss, step_fps), flush=True)
+
+    def epoch_begin(self, run_context):
+        self.epoch_start = time.time()
+        self.losses = []
+
+    def epoch_end(self, run_context):
+        cb_params = run_context.original_args()
+        epoch_cost = time.time() - self.epoch_start
+        step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
+        step_fps = self.batch_size * 1.0 * step_in_epoch / epoch_cost
+        print("epoch: {:3d}, avg loss:{:.4f}, total cost: {:.3f} s, per step fps:{:5.3f}".format(
+            cb_params.cur_epoch_num, np.mean(self.losses), epoch_cost, step_fps), flush=True)
+
 
 def mask_to_image(mask):
     return Image.fromarray((mask * 255).astype(np.uint8))
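
The new epoch_end hook reports throughput as images per second over the whole epoch: batch_size × steps ÷ wall-clock time, with the step count recovered from cb_params so it stays correct even if an epoch is cut short. The arithmetic, with illustrative numbers:

# batch_size=16, 334 steps in the epoch, 60 s of wall clock:
step_fps = 16 * 1.0 * 334 / 60.0
print(step_fps)   # ~89.1 images per second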

+28 -13  model_zoo/official/cv/unet/train.py

@@ -21,15 +21,15 @@ import ast
 import mindspore
 import mindspore.nn as nn
 from mindspore import Model, context
-from mindspore.communication.management import init, get_group_size
+from mindspore.communication.management import init, get_group_size, get_rank
 from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
 from mindspore.context import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 
 from src.unet_medical import UNetMedical
 from src.unet_nested import NestedUNet, UNet
-from src.data_loader import create_dataset
-from src.loss import CrossEntropyWithLogits
+from src.data_loader import create_dataset, create_cell_nuclei_dataset
+from src.loss import CrossEntropyWithLogits, MultiCrossEntropyWithLogits
 from src.utils import StepLossTimeMonitor
 from src.config import cfg_unet
@@ -46,10 +46,12 @@ def train_net(data_dir,
               run_distribute=False,
               cfg=None):
 
+    rank = 0
+    group_size = 1
     if run_distribute:
         init()
         group_size = get_group_size()
+        rank = get_rank()
         parallel_mode = ParallelMode.DATA_PARALLEL
         context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                           device_num=group_size,
@@ -58,7 +60,8 @@ def train_net(data_dir,
     if cfg['model'] == 'unet_medical':
         net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
     elif cfg['model'] == 'unet_nested':
-        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
+        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
+                         use_bn=cfg['use_bn'], use_ds=cfg['use_ds'])
     elif cfg['model'] == 'unet_simple':
         net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
     else:
@@ -68,14 +71,28 @@ def train_net(data_dir,
         param_dict = load_checkpoint(cfg['resume_ckpt'])
         load_param_into_net(net, param_dict)
 
-    criterion = CrossEntropyWithLogits()
-    train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute, cfg["crop"],
-                                      cfg['img_size'])
+    if 'use_ds' in cfg and cfg['use_ds']:
+        criterion = MultiCrossEntropyWithLogits()
+    else:
+        criterion = CrossEntropyWithLogits()
+    if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
+        repeat = 10
+        dataset_sink_mode = True
+        per_print_times = 0
+        train_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], repeat, batch_size,
+                                                   is_train=True, augment=True, split=0.8, rank=rank,
+                                                   group_size=group_size)
+    else:
+        repeat = epochs
+        dataset_sink_mode = False
+        per_print_times = 1
+        train_dataset, _ = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind, run_distribute,
+                                          cfg["crop"], cfg['img_size'])
     train_data_size = train_dataset.get_dataset_size()
     print("dataset length is:", train_data_size)
     ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
                                    keep_checkpoint_max=cfg['keep_checkpoint_max'])
-    ckpoint_cb = ModelCheckpoint(prefix='ckpt_unet_medical_adam',
+    ckpoint_cb = ModelCheckpoint(prefix='ckpt_{}_adam'.format(cfg['model']),
                                  directory='./ckpt_{}/'.format(device_id),
                                  config=ckpt_config)
@@ -87,13 +104,11 @@ def train_net(data_dir,
     model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3")
 
     print("============== Starting Training ==============")
-    model.train(1, train_dataset, callbacks=[StepLossTimeMonitor(batch_size=batch_size), ckpoint_cb],
-                dataset_sink_mode=False)
+    callbacks = [StepLossTimeMonitor(batch_size=batch_size, per_print_times=per_print_times), ckpoint_cb]
+    model.train(int(epochs / repeat), train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)
     print("============== End Training ==============")
 
 
 def get_args():
     parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
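
The repeat/epochs bookkeeping is the subtle part of the last hunk. On the cell nuclei path the dataset itself already replays the training ids repeat=10 times per dataset epoch (see CellNucleiDataset), so model.train() is called with int(epochs / repeat) outer epochs and sink mode enabled; the total number of passes over the data is unchanged. Illustrative arithmetic for cfg_unet_nested_cell:

epochs = 200      # cfg['epochs']
repeat = 10       # each dataset epoch walks the training split 10 times
outer_epochs = int(epochs / repeat)
print(outer_epochs)                  # 20 train-loop epochs, 200 effective passes
assert outer_epochs * repeat == epochs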

