@@ -24,7 +24,7 @@ from mindspore import context, Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.nn.loss.loss import _Loss
-from src.data_loader import create_dataset
+from src.data_loader import create_dataset, create_cell_nuclei_dataset
 from src.unet_medical import UNetMedical
 from src.unet_nested import NestedUNet, UNet
 from src.config import cfg_unet
@@ -59,6 +59,7 @@ class dice_coeff(nn.Metric):
         self.clear()

     def clear(self):
         self._dice_coeff_sum = 0
+        self._iou_sum = 0
         self._samples_num = 0

     def update(self, *inputs):
@@ -77,13 +78,15 @@ class dice_coeff(nn.Metric):
         union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
         single_dice_coeff = 2*float(inter)/float(union+1e-6)
-        print("single dice coeff is:", single_dice_coeff)
+        single_iou = single_dice_coeff / (2 - single_dice_coeff)
+        print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou))
         self._dice_coeff_sum += single_dice_coeff
+        self._iou_sum += single_iou

     def eval(self):
         if self._samples_num == 0:
             raise RuntimeError('Total samples num must not be 0.')
-        return self._dice_coeff_sum / float(self._samples_num)
+        return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num))

 def test_net(data_dir,
@@ -93,7 +96,8 @@ def test_net(data_dir,
     if cfg['model'] == 'unet_medical':
         net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
     elif cfg['model'] == 'unet_nested':
-        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
+        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
+                         use_bn=cfg['use_bn'], use_ds=False)
     elif cfg['model'] == 'unet_simple':
         net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
     else:
@@ -102,13 +106,17 @@ def test_net(data_dir,
     load_param_into_net(net, param_dict)

     criterion = CrossEntropyWithLogits()
-    _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
-                                      do_crop=cfg['crop'], img_size=cfg['img_size'])
+    if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
+        valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False, split=0.8)
+    else:
+        _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
+                                          do_crop=cfg['crop'], img_size=cfg['img_size'])
     model = Model(net, loss_fn=criterion, metrics={"dice_coeff": dice_coeff()})

     print("============== Starting Evaluating ============")
-    dice_score = model.eval(valid_dataset, dataset_sink_mode=False)
-    print("============== Cross valid dice coeff is:", dice_score)
+    eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"]
+    print("============== Cross valid dice coeff is:", eval_score[0])
+    print("============== Cross valid IOU is:", eval_score[1])

 def get_args():
@@ -42,7 +42,8 @@ if __name__ == "__main__":
     if cfg['model'] == 'unet_medical':
         net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
     elif cfg['model'] == 'unet_nested':
-        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
+        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
+                         use_bn=cfg['use_bn'], use_ds=False)
     elif cfg['model'] == 'unet_simple':
         net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
     else:
@@ -50,6 +50,34 @@ cfg_unet_nested = {
     'weight_decay': 0.0005,
     'loss_scale': 1024.0,
     'FixedLossScaleManager': 1024.0,
+    'use_bn': True,
+    'use_ds': True,
+    'use_deconv': True,
+    'resume': False,
+    'resume_ckpt': './',
 }
+
+cfg_unet_nested_cell = {
+    'model': 'unet_nested',
+    'dataset': 'Cell_nuclei',
+    'crop': None,
+    'img_size': [96, 96],
+    'lr': 3e-4,
+    'epochs': 200,
+    'distribute_epochs': 1600,
+    'batchsize': 16,
+    'cross_valid_ind': 1,
+    'num_classes': 2,
+    'num_channels': 3,
+    'keep_checkpoint_max': 10,
+    'weight_decay': 0.0005,
+    'loss_scale': 1024.0,
+    'FixedLossScaleManager': 1024.0,
+    'use_bn': True,
+    'use_ds': True,
+    'use_deconv': True,
+    'resume': False,
+    'resume_ckpt': './',
+}
@@ -15,6 +15,7 @@
 import os
 from collections import deque
+import cv2
 import numpy as np
 from PIL import Image, ImageSequence
 import mindspore.dataset as ds
@@ -23,7 +24,6 @@ from mindspore.dataset.vision.utils import Inter
 from mindspore.communication.management import get_rank, get_group_size

-
 def _load_multipage_tiff(path):
     """Load tiff images containing many images in the channel dimension"""
     return np.array([np.array(p) for p in ImageSequence.Iterator(Image.open(path))])
@@ -164,3 +164,100 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro
     valid_ds = valid_ds.batch(batch_size=1, drop_remainder=True)

     return train_ds, valid_ds
+
+
+class CellNucleiDataset:
+    """
+    Cell nuclei dataset preprocess class.
+    """
+    def __init__(self, data_dir, repeat, is_train=False, split=0.8):
+        self.data_dir = data_dir
+        self.img_ids = sorted(next(os.walk(self.data_dir))[1])
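+        # each sub-directory of data_dir is one sample id; the first `split`
+        # fraction becomes the train list, repeated `repeat` times so a single
+        # dataset pass spans several logical epochs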
+        self.train_ids = self.img_ids[:int(len(self.img_ids) * split)] * repeat
+        np.random.shuffle(self.train_ids)
+        self.val_ids = self.img_ids[int(len(self.img_ids) * split):]
+        self.is_train = is_train
+        self._preprocess_dataset()
+
+    def _preprocess_dataset(self):
+        for img_id in self.img_ids:
+            path = os.path.join(self.data_dir, img_id)
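+            # build a cached image.png/mask.png pair once per sample; later
+            # runs and epochs read the merged files directly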
+            if (not os.path.exists(os.path.join(path, "image.png"))) or \
+                    (not os.path.exists(os.path.join(path, "mask.png"))):
+                img = cv2.imread(os.path.join(path, "images", img_id + ".png"))
+                if len(img.shape) == 2:
+                    img = np.expand_dims(img, axis=-1)
+                    img = np.concatenate([img, img, img], axis=-1)
+                mask = []
+                for mask_file in next(os.walk(os.path.join(path, "masks")))[2]:
+                    mask_ = cv2.imread(os.path.join(path, "masks", mask_file), cv2.IMREAD_GRAYSCALE)
+                    mask.append(mask_)
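+                # each file under masks/ is a single-instance mask; a pixel-wise
+                # max collapses them into one binary foreground mask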
+                mask = np.max(mask, axis=0)
+                cv2.imwrite(os.path.join(path, "image.png"), img)
+                cv2.imwrite(os.path.join(path, "mask.png"), mask)
+
+    def _read_img_mask(self, img_id):
+        path = os.path.join(self.data_dir, img_id)
+        img = cv2.imread(os.path.join(path, "image.png"))
+        mask = cv2.imread(os.path.join(path, "mask.png"), cv2.IMREAD_GRAYSCALE)
+        return img, mask
+
+    def __getitem__(self, index):
+        if self.is_train:
+            return self._read_img_mask(self.train_ids[index])
+        return self._read_img_mask(self.val_ids[index])
+
+    @property
+    def column_names(self):
+        column_names = ['image', 'mask']
+        return column_names
+
+    def __len__(self):
+        if self.is_train:
+            return len(self.train_ids)
+        return len(self.val_ids)
+
+
+def preprocess_img_mask(img, mask, img_size, augment=False):
+    """
+    Preprocess for cell nuclei dataset.
+    Random crop and flip images and masks when augment is True.
+    """
+    if augment:
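+        # random scale-then-crop: resize to 1.0-1.5x the target size, then cut
+        # a window of exactly img_size back out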
+        img_size_w = int(np.random.randint(img_size[0], img_size[0] * 1.5, 1))
+        img_size_h = int(np.random.randint(img_size[1], img_size[1] * 1.5, 1))
+        img = cv2.resize(img, (img_size_w, img_size_h))
+        mask = cv2.resize(mask, (img_size_w, img_size_h))
+        dw = int(np.random.randint(0, img_size_w - img_size[0] + 1, 1))
+        dh = int(np.random.randint(0, img_size_h - img_size[1] + 1, 1))
+        img = img[dh:dh+img_size[1], dw:dw+img_size[0], :]
+        mask = mask[dh:dh+img_size[1], dw:dw+img_size[0]]
+        if np.random.random() > 0.5:
+            flip_code = int(np.random.randint(-1, 2, 1))
+            img = cv2.flip(img, flip_code)
+            mask = cv2.flip(mask, flip_code)
+    else:
+        img = cv2.resize(img, img_size)
+        mask = cv2.resize(mask, img_size)
+    img = (img.astype(np.float32) - 127.5) / 127.5
+    img = img.transpose(2, 0, 1)
+    mask = mask.astype(np.float32) / 255
+    mask = (mask > 0.5).astype(np.int32)
+    mask = (np.arange(2) == mask[..., None]).astype(int)
+    mask = mask.transpose(2, 0, 1).astype(np.float32)
+    return img, mask
+
+
+def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train=False, augment=False,
+                               split=0.8, rank=0, group_size=1, python_multiprocessing=True, num_parallel_workers=8):
+    """
+    Get generator dataset for cell nuclei dataset.
+    """
+    cell_dataset = CellNucleiDataset(data_dir, repeat, is_train, split)
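+    # DistributedSampler shards sample indices across group_size ranks so each
+    # device sees a disjoint subset; shuffle only for the training pipeline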
+    sampler = ds.DistributedSampler(group_size, rank, shuffle=is_train)
+    dataset = ds.GeneratorDataset(cell_dataset, cell_dataset.column_names, sampler=sampler)
+    compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, tuple(img_size), augment and is_train))
+    dataset = dataset.map(operations=compose_map_func, input_columns=cell_dataset.column_names,
+                          output_columns=cell_dataset.column_names, column_order=cell_dataset.column_names,
+                          python_multiprocessing=python_multiprocessing,
+                          num_parallel_workers=num_parallel_workers)
+    dataset = dataset.batch(batch_size, drop_remainder=is_train)
+    dataset = dataset.repeat(1)
+    return dataset
@@ -36,3 +36,15 @@ class CrossEntropyWithLogits(_Loss):
         loss = self.reduce_mean(
             self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, 2)), self.reshape_fn(label, (-1, 2))))
         return self.get_loss(loss)
+
+
+class MultiCrossEntropyWithLogits(nn.Cell):
+    def __init__(self):
+        super(MultiCrossEntropyWithLogits, self).__init__()
+        self.loss = CrossEntropyWithLogits()
+        self.squeeze = F.Squeeze()
+
+    def construct(self, logits, label):
+        total_loss = 0
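+        # deep supervision: `logits` is the (4, N, C, H, W) stack of side outputs
+        # from NestedUNet; squeeze each (1, N, C, H, W) slice back to a plain
+        # prediction and sum its loss against the shared label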
+        for i in range(len(logits)):
+            total_loss += self.loss(self.squeeze(logits[i:i+1]), label)
+        return total_loss
@@ -16,6 +16,7 @@
 # Model of UnetPlusPlus
 import mindspore.nn as nn
+import mindspore.ops as P
 from .unet_parts import UnetConv2d, UnetUp
@@ -63,6 +64,7 @@ class NestedUNet(nn.Cell):
         self.final2 = nn.Conv2d(filters[0], n_class, 1)
         self.final3 = nn.Conv2d(filters[0], n_class, 1)
         self.final4 = nn.Conv2d(filters[0], n_class, 1)
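+        # collect the four side outputs along a new leading axis; with deep
+        # supervision on, the loss consumes them as one (4, N, C, H, W) tensor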
+        self.stack = P.Stack(axis=0)

     def construct(self, inputs):
         x00 = self.conv00(inputs)  # channel = filters[0]
@@ -86,13 +88,12 @@ class NestedUNet(nn.Cell):
         x04 = self.up_concat04(x13, x00, x01, x02, x03)  # channel = filters[0]

         final1 = self.final1(x01)
-        final2 = self.final1(x02)
-        final3 = self.final1(x03)
-        final4 = self.final1(x04)
-        final = (final1 + final2 + final3 + final4) / 4.0
-        return final
+        final2 = self.final2(x02)
+        final3 = self.final3(x03)
+        final4 = self.final4(x04)
+        if self.use_ds:
+            final = self.stack((final1, final2, final3, final4))
+            return final
+        return final4
@@ -51,9 +51,23 @@ class StepLossTimeMonitor(Callback):
         if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
             raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
                 cb_params.cur_epoch_num, cur_step_in_epoch))
+        self.losses.append(loss)

         if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
-            # TEST
             print("step: %s, loss is %s, fps is %s" % (cur_step_in_epoch, loss, step_fps), flush=True)
+
+    def epoch_begin(self, run_context):
+        self.epoch_start = time.time()
+        self.losses = []
+
+    def epoch_end(self, run_context):
+        cb_params = run_context.original_args()
+        epoch_cost = time.time() - self.epoch_start
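+        # per-step fps = images processed this epoch / wall-clock epoch time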
+        step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
+        step_fps = self.batch_size * 1.0 * step_in_epoch / epoch_cost
+        print("epoch: {:3d}, avg loss:{:.4f}, total cost: {:.3f} s, per step fps:{:5.3f}".format(
+            cb_params.cur_epoch_num, np.mean(self.losses), epoch_cost, step_fps), flush=True)

 def mask_to_image(mask):
     return Image.fromarray((mask * 255).astype(np.uint8))
@@ -21,15 +21,15 @@ import ast
 import mindspore
 import mindspore.nn as nn
 from mindspore import Model, context
-from mindspore.communication.management import init, get_group_size
+from mindspore.communication.management import init, get_group_size, get_rank
 from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
 from mindspore.context import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net

 from src.unet_medical import UNetMedical
 from src.unet_nested import NestedUNet, UNet
-from src.data_loader import create_dataset
-from src.loss import CrossEntropyWithLogits
+from src.data_loader import create_dataset, create_cell_nuclei_dataset
+from src.loss import CrossEntropyWithLogits, MultiCrossEntropyWithLogits
 from src.utils import StepLossTimeMonitor
 from src.config import cfg_unet
@@ -46,10 +46,12 @@ def train_net(data_dir,
               run_distribute=False,
               cfg=None):

+    rank = 0
     group_size = 1
     if run_distribute:
         init()
         group_size = get_group_size()
+        rank = get_rank()
         parallel_mode = ParallelMode.DATA_PARALLEL
         context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                           device_num=group_size,
@@ -58,7 +60,8 @@ def train_net(data_dir,
     if cfg['model'] == 'unet_medical':
         net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
     elif cfg['model'] == 'unet_nested':
-        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
+        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
+                         use_bn=cfg['use_bn'], use_ds=cfg['use_ds'])
     elif cfg['model'] == 'unet_simple':
         net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
     else:
@@ -68,14 +71,28 @@ def train_net(data_dir,
         param_dict = load_checkpoint(cfg['resume_ckpt'])
         load_param_into_net(net, param_dict)

-    criterion = CrossEntropyWithLogits()
-    train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute, cfg["crop"],
-                                      cfg['img_size'])
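+    # with deep supervision the net returns four stacked logits, so the
+    # multi-head loss is needed; otherwise keep the single-output loss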
+    if 'use_ds' in cfg and cfg['use_ds']:
+        criterion = MultiCrossEntropyWithLogits()
+    else:
+        criterion = CrossEntropyWithLogits()
+    if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
+        repeat = 10
+        dataset_sink_mode = True
+        per_print_times = 0
+        train_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], repeat, batch_size,
+                                                   is_train=True, augment=True, split=0.8, rank=rank,
+                                                   group_size=group_size)
+    else:
+        repeat = epochs
+        dataset_sink_mode = False
+        per_print_times = 1
+        train_dataset, _ = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind, run_distribute,
+                                          cfg["crop"], cfg['img_size'])
     train_data_size = train_dataset.get_dataset_size()
     print("dataset length is:", train_data_size)
     ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
                                    keep_checkpoint_max=cfg['keep_checkpoint_max'])
-    ckpoint_cb = ModelCheckpoint(prefix='ckpt_unet_medical_adam',
+    ckpoint_cb = ModelCheckpoint(prefix='ckpt_{}_adam'.format(cfg['model']),
                                  directory='./ckpt_{}/'.format(device_id),
                                  config=ckpt_config)
@@ -87,13 +104,11 @@ def train_net(data_dir,
     model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3")

     print("============== Starting Training ==============")
-    model.train(1, train_dataset, callbacks=[StepLossTimeMonitor(batch_size=batch_size), ckpoint_cb],
-                dataset_sink_mode=False)
+    callbacks = [StepLossTimeMonitor(batch_size=batch_size, per_print_times=per_print_times), ckpoint_cb]
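+    # the dataset already repeats `repeat` logical epochs per pass, so
+    # epochs / repeat passes keep the total number of epochs unchanged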
+    model.train(int(epochs / repeat), train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)
     print("============== End Training ==============")

 def get_args():
     parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)