diff --git a/model_zoo/official/cv/unet/eval.py b/model_zoo/official/cv/unet/eval.py index 6b8e746c6e..f3a2cce6f4 100644 --- a/model_zoo/official/cv/unet/eval.py +++ b/model_zoo/official/cv/unet/eval.py @@ -25,7 +25,8 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.nn.loss.loss import _Loss from src.data_loader import create_dataset -from src.unet import UNet +from src.unet_medical import UNetMedical +from src.unet_nested import NestedUNet, UNet from src.config import cfg_unet from scipy.special import softmax @@ -34,8 +35,6 @@ device_id = int(os.getenv('DEVICE_ID')) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) - - class CrossEntropyWithLogits(_Loss): def __init__(self): super(CrossEntropyWithLogits, self).__init__() @@ -64,10 +63,11 @@ class dice_coeff(nn.Metric): def update(self, *inputs): if len(inputs) != 2: - raise ValueError('Mean dice coeffcient need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) + raise ValueError('Mean dice coefficient need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) y_pred = self._convert_data(inputs[0]) y = self._convert_data(inputs[1]) + self._samples_num += y.shape[0] y_pred = y_pred.transpose(0, 2, 3, 1) y = y.transpose(0, 2, 3, 1) @@ -90,13 +90,20 @@ def test_net(data_dir, ckpt_path, cross_valid_ind=1, cfg=None): - - net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) + if cfg['model'] == 'unet_medical': + net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) + elif cfg['model'] == 'unet_nested': + net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) + elif cfg['model'] == 'unet_simple': + net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) + else: + raise ValueError("Unsupported model: {}".format(cfg['model'])) param_dict = load_checkpoint(ckpt_path) load_param_into_net(net, param_dict) criterion = CrossEntropyWithLogits() - _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False) + _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, + do_crop=cfg['crop'], img_size=cfg['img_size']) model = Model(net, loss_fn=criterion, metrics={"dice_coeff": dice_coeff()}) print("============== Starting Evaluating ============") diff --git a/model_zoo/official/cv/unet/export.py b/model_zoo/official/cv/unet/export.py index fc72b65dd8..17eb773f40 100644 --- a/model_zoo/official/cv/unet/export.py +++ b/model_zoo/official/cv/unet/export.py @@ -18,7 +18,8 @@ import numpy as np from mindspore import Tensor, export, load_checkpoint, load_param_into_net, context -from src.unet.unet_model import UNet +from src.unet_medical.unet_model import UNetMedical +from src.unet_nested import NestedUNet, UNet from src.config import cfg_unet as cfg parser = argparse.ArgumentParser(description='unet export') @@ -38,7 +39,14 @@ if args.device_target == "Ascend": context.set_context(device_id=args.device_id) if __name__ == "__main__": - net = UNet(n_channels=cfg["num_channels"], n_classes=cfg["num_classes"]) + if cfg['model'] == 'unet_medical': + net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) + elif cfg['model'] == 'unet_nested': + net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) + elif cfg['model'] == 'unet_simple': + net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) + else: + raise ValueError("Unsupported model: {}".format(cfg['model'])) # return a parameter dict for model param_dict = load_checkpoint(args.ckpt_file) # load the parameter into net diff --git a/model_zoo/official/cv/unet/mindspore_hub_conf.py b/model_zoo/official/cv/unet/mindspore_hub_conf.py index 4e292980da..90955c10a5 100644 --- a/model_zoo/official/cv/unet/mindspore_hub_conf.py +++ b/model_zoo/official/cv/unet/mindspore_hub_conf.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ """hub config.""" -from src.unet import UNet +from src.unet_medical import UNet def create_network(name, *args, **kwargs): if name == "unet2d": diff --git a/model_zoo/official/cv/unet/scripts/run_distribute_train.sh b/model_zoo/official/cv/unet/scripts/run_distribute_train.sh index 807490db21..f00014d27f 100644 --- a/model_zoo/official/cv/unet/scripts/run_distribute_train.sh +++ b/model_zoo/official/cv/unet/scripts/run_distribute_train.sh @@ -14,15 +14,14 @@ # limitations under the License. # ============================================================================ -echo "==============================================================================================================" -echo "Please run the script as: " -echo "bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]" -echo "for example: bash run_distribute_train.sh /absolute/path/to/RANK_TABLE_FILE /absolute/path/to/data" -echo "==============================================================================================================" - if [ $# != 2 ] then + echo "==============================================================================================================" echo "Usage: bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]" + echo "Please run the script as: " + echo "bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]" + echo "for example: bash run_distribute_train.sh /absolute/path/to/RANK_TABLE_FILE /absolute/path/to/data" + echo "==============================================================================================================" exit 1 fi diff --git a/model_zoo/official/cv/unet/scripts/run_standalone_eval.sh b/model_zoo/official/cv/unet/scripts/run_standalone_eval.sh index 2bfccfb331..ded1f908a0 100644 --- a/model_zoo/official/cv/unet/scripts/run_standalone_eval.sh +++ b/model_zoo/official/cv/unet/scripts/run_standalone_eval.sh @@ -14,11 +14,14 @@ # limitations under the License. # ============================================================================ -echo "==============================================================================================================" -echo "Please run the script as: " -echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]" -echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/" -echo "==============================================================================================================" +if [ $# != 2 ] +then + echo "==============================================================================================================" + echo "Please run the script as: " + echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]" + echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/" + echo "==============================================================================================================" +fi export DEVICE_ID=0 python eval.py --data_url=$1 --ckpt_path=$2 > eval.log 2>&1 & \ No newline at end of file diff --git a/model_zoo/official/cv/unet/scripts/run_standalone_train.sh b/model_zoo/official/cv/unet/scripts/run_standalone_train.sh index 3a1594d51f..c8089b5a2b 100644 --- a/model_zoo/official/cv/unet/scripts/run_standalone_train.sh +++ b/model_zoo/official/cv/unet/scripts/run_standalone_train.sh @@ -14,11 +14,14 @@ # limitations under the License. # ============================================================================ -echo "==============================================================================================================" -echo "Please run the script as: " -echo "bash scripts/run_standalone_train.sh [DATASET]" -echo "for example: bash run_standalone_train.sh /path/to/data/" -echo "==============================================================================================================" +if [ $# != 1 ] +then + echo "==============================================================================================================" + echo "Please run the script as: " + echo "bash scripts/run_standalone_train.sh [DATASET]" + echo "for example: bash run_standalone_train.sh /path/to/data/" + echo "==============================================================================================================" +fi export DEVICE_ID=0 python train.py --data_url=$1 > train.log 2>&1 & \ No newline at end of file diff --git a/model_zoo/official/cv/unet/src/config.py b/model_zoo/official/cv/unet/src/config.py index 9f8eba7b45..2e34a7bd6e 100644 --- a/model_zoo/official/cv/unet/src/config.py +++ b/model_zoo/official/cv/unet/src/config.py @@ -13,7 +13,10 @@ # limitations under the License. # ============================================================================ -cfg_unet = { +cfg_unet_medical = { + 'model': 'unet_medical', + 'crop': [388 / 572, 388 / 572], + 'img_size': [572, 572], 'lr': 0.0001, 'epochs': 400, 'distribute_epochs': 1600, @@ -30,3 +33,47 @@ cfg_unet = { 'resume': False, 'resume_ckpt': './', } + +cfg_unet_nested = { + 'model': 'unet_nested', + 'crop': None, + 'img_size': [576, 576], + 'lr': 0.0001, + 'epochs': 400, + 'distribute_epochs': 1600, + 'batchsize': 16, + 'cross_valid_ind': 1, + 'num_classes': 2, + 'num_channels': 1, + + 'keep_checkpoint_max': 10, + 'weight_decay': 0.0005, + 'loss_scale': 1024.0, + 'FixedLossScaleManager': 1024.0, + + 'resume': False, + 'resume_ckpt': './', +} + +cfg_unet_simple = { + 'model': 'unet_simple', + 'crop': None, + 'img_size': [576, 576], + 'lr': 0.0001, + 'epochs': 400, + 'distribute_epochs': 1600, + 'batchsize': 16, + 'cross_valid_ind': 1, + 'num_classes': 2, + 'num_channels': 1, + + 'keep_checkpoint_max': 10, + 'weight_decay': 0.0005, + 'loss_scale': 1024.0, + 'FixedLossScaleManager': 1024.0, + + 'resume': False, + 'resume_ckpt': './', +} + +cfg_unet = cfg_unet_medical diff --git a/model_zoo/official/cv/unet/src/data_loader.py b/model_zoo/official/cv/unet/src/data_loader.py index dbf5b664cf..32f056628e 100644 --- a/model_zoo/official/cv/unet/src/data_loader.py +++ b/model_zoo/official/cv/unet/src/data_loader.py @@ -82,7 +82,8 @@ def train_data_augmentation(img, mask): return img, mask -def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cross_val_ind=1, run_distribute=False): +def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cross_val_ind=1, run_distribute=False, + do_crop=None, img_size=None): images = _load_multipage_tiff(os.path.join(data_dir, 'train-volume.tif')) masks = _load_multipage_tiff(os.path.join(data_dir, 'train-labels.tif')) @@ -121,8 +122,12 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro ds_valid_images = ds.NumpySlicesDataset(data=valid_image_data, sampler=None, shuffle=False) ds_valid_masks = ds.NumpySlicesDataset(data=valid_mask_data, sampler=None, shuffle=False) - c_resize_op = c_vision.Resize(size=(388, 388), interpolation=Inter.BILINEAR) - c_pad = c_vision.Pad(padding=92) + if do_crop: + resize_size = [int(img_size[x] * do_crop[x]) for x in range(len(img_size))] + else: + resize_size = img_size + c_resize_op = c_vision.Resize(size=(resize_size[0], resize_size[1]), interpolation=Inter.BILINEAR) + c_pad = c_vision.Pad(padding=(img_size[0] - resize_size[0]) // 2) c_rescale_image = c_vision.Rescale(1.0/127.5, -1) c_rescale_mask = c_vision.Rescale(1.0/255.0, 0) @@ -136,12 +141,13 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro train_ds = train_ds.project(columns=["image", "mask"]) if augment: augment_process = train_data_augmentation - c_resize_op = c_vision.Resize(size=(572, 572), interpolation=Inter.BILINEAR) + c_resize_op = c_vision.Resize(size=(img_size[0], img_size[1]), interpolation=Inter.BILINEAR) train_ds = train_ds.map(input_columns=["image", "mask"], operations=augment_process) train_ds = train_ds.map(input_columns="image", operations=c_resize_op) train_ds = train_ds.map(input_columns="mask", operations=c_resize_op) - train_ds = train_ds.map(input_columns="mask", operations=c_center_crop) + if do_crop: + train_ds = train_ds.map(input_columns="mask", operations=c_center_crop) post_process = data_post_process train_ds = train_ds.map(input_columns=["image", "mask"], operations=post_process) train_ds = train_ds.shuffle(repeat*24) @@ -151,7 +157,8 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro valid_mask_ds = ds_valid_masks.map(input_columns="mask", operations=c_trans_normalize_mask) valid_ds = ds.zip((valid_image_ds, valid_mask_ds)) valid_ds = valid_ds.project(columns=["image", "mask"]) - valid_ds = valid_ds.map(input_columns="mask", operations=c_center_crop) + if do_crop: + valid_ds = valid_ds.map(input_columns="mask", operations=c_center_crop) post_process = data_post_process valid_ds = valid_ds.map(input_columns=["image", "mask"], operations=post_process) valid_ds = valid_ds.batch(batch_size=1, drop_remainder=True) diff --git a/model_zoo/official/cv/unet/src/unet/__init__.py b/model_zoo/official/cv/unet/src/unet_medical/__init__.py similarity index 94% rename from model_zoo/official/cv/unet/src/unet/__init__.py rename to model_zoo/official/cv/unet/src/unet_medical/__init__.py index de5b46ac43..da537aac8b 100644 --- a/model_zoo/official/cv/unet/src/unet/__init__.py +++ b/model_zoo/official/cv/unet/src/unet_medical/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. # ============================================================================ -from .unet_model import UNet +from .unet_model import UNetMedical diff --git a/model_zoo/official/cv/unet/src/unet/unet_model.py b/model_zoo/official/cv/unet/src/unet_medical/unet_model.py similarity index 90% rename from model_zoo/official/cv/unet/src/unet/unet_model.py rename to model_zoo/official/cv/unet/src/unet_medical/unet_model.py index 2949d51cab..26d5f271c6 100644 --- a/model_zoo/official/cv/unet/src/unet/unet_model.py +++ b/model_zoo/official/cv/unet/src/unet_medical/unet_model.py @@ -13,12 +13,12 @@ # limitations under the License. # ============================================================================ -from src.unet.unet_parts import DoubleConv, Down, Up1, Up2, Up3, Up4, OutConv +from src.unet_medical.unet_parts import DoubleConv, Down, Up1, Up2, Up3, Up4, OutConv import mindspore.nn as nn -class UNet(nn.Cell): +class UNetMedical(nn.Cell): def __init__(self, n_channels, n_classes): - super(UNet, self).__init__() + super(UNetMedical, self).__init__() self.n_channels = n_channels self.n_classes = n_classes self.inc = DoubleConv(n_channels, 64) diff --git a/model_zoo/official/cv/unet/src/unet/unet_parts.py b/model_zoo/official/cv/unet/src/unet_medical/unet_parts.py similarity index 100% rename from model_zoo/official/cv/unet/src/unet/unet_parts.py rename to model_zoo/official/cv/unet/src/unet_medical/unet_parts.py diff --git a/model_zoo/official/cv/unet/src/unet_nested/__init__.py b/model_zoo/official/cv/unet/src/unet_nested/__init__.py new file mode 100644 index 0000000000..6a392415be --- /dev/null +++ b/model_zoo/official/cv/unet/src/unet_nested/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# less required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from .unet_model import NestedUNet, UNet diff --git a/model_zoo/official/cv/unet/src/unet_nested/unet_model.py b/model_zoo/official/cv/unet/src/unet_nested/unet_model.py new file mode 100644 index 0000000000..bbfba469db --- /dev/null +++ b/model_zoo/official/cv/unet/src/unet_nested/unet_model.py @@ -0,0 +1,146 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# less required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +# Model of UnetPlusPlus + +import mindspore.nn as nn +from .unet_parts import UnetConv2d, UnetUp + + +class NestedUNet(nn.Cell): + """ + Nested unet + """ + def __init__(self, in_channel, n_class=2, feature_scale=2, use_deconv=True, use_bn=True, use_ds=True): + super(NestedUNet, self).__init__() + self.in_channel = in_channel + self.n_class = n_class + self.feature_scale = feature_scale + self.use_deconv = use_deconv + self.use_bn = use_bn + self.use_ds = use_ds + + filters = [64, 128, 256, 512, 1024] + filters = [int(x / self.feature_scale) for x in filters] + + # Down Sample + self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same") + self.conv00 = UnetConv2d(self.in_channel, filters[0], self.use_bn) + self.conv10 = UnetConv2d(filters[0], filters[1], self.use_bn) + self.conv20 = UnetConv2d(filters[1], filters[2], self.use_bn) + self.conv30 = UnetConv2d(filters[2], filters[3], self.use_bn) + self.conv40 = UnetConv2d(filters[3], filters[4], self.use_bn) + + # Up Sample + self.up_concat01 = UnetUp(filters[1], filters[0], self.use_deconv, 2) + self.up_concat11 = UnetUp(filters[2], filters[1], self.use_deconv, 2) + self.up_concat21 = UnetUp(filters[3], filters[2], self.use_deconv, 2) + self.up_concat31 = UnetUp(filters[4], filters[3], self.use_deconv, 2) + + self.up_concat02 = UnetUp(filters[1], filters[0], self.use_deconv, 3) + self.up_concat12 = UnetUp(filters[2], filters[1], self.use_deconv, 3) + self.up_concat22 = UnetUp(filters[3], filters[2], self.use_deconv, 3) + + self.up_concat03 = UnetUp(filters[1], filters[0], self.use_deconv, 4) + self.up_concat13 = UnetUp(filters[2], filters[1], self.use_deconv, 4) + + self.up_concat04 = UnetUp(filters[1], filters[0], self.use_deconv, 5) + + # Finale Convolution + self.final1 = nn.Conv2d(filters[0], n_class, 1) + self.final2 = nn.Conv2d(filters[0], n_class, 1) + self.final3 = nn.Conv2d(filters[0], n_class, 1) + self.final4 = nn.Conv2d(filters[0], n_class, 1) + + def construct(self, inputs): + x00 = self.conv00(inputs) # channel = filters[0] + x10 = self.conv10(self.maxpool(x00)) # channel = filters[1] + x20 = self.conv20(self.maxpool(x10)) # channel = filters[2] + x30 = self.conv30(self.maxpool(x20)) # channel = filters[3] + x40 = self.conv40(self.maxpool(x30)) # channel = filters[4] + + x01 = self.up_concat01(x10, x00) # channel = filters[0] + x11 = self.up_concat11(x20, x10) # channel = filters[1] + x21 = self.up_concat21(x30, x20) # channel = filters[2] + x31 = self.up_concat31(x40, x30) # channel = filters[3] + + x02 = self.up_concat02(x11, x00, x01) # channel = filters[0] + x12 = self.up_concat12(x21, x10, x11) # channel = filters[1] + x22 = self.up_concat22(x31, x20, x21) # channel = filters[2] + + x03 = self.up_concat03(x12, x00, x01, x02) # channel = filters[0] + x13 = self.up_concat13(x22, x10, x11, x12) # channel = filters[1] + + x04 = self.up_concat04(x13, x00, x01, x02, x03) # channel = filters[0] + + final1 = self.final1(x01) + final2 = self.final1(x02) + final3 = self.final1(x03) + final4 = self.final1(x04) + + final = (final1 + final2 + final3 + final4) / 4.0 + + if self.use_ds: + return final + return final4 + + +class UNet(nn.Cell): + """ + Simple UNet with skip connection + """ + def __init__(self, in_channel, n_class=2, feature_scale=2, use_deconv=True, use_bn=True): + super(UNet, self).__init__() + self.in_channel = in_channel + self.n_class = n_class + self.feature_scale = feature_scale + self.use_deconv = use_deconv + self.use_bn = use_bn + + filters = [64, 128, 256, 512, 1024] + filters = [int(x / self.feature_scale) for x in filters] + + # Down Sample + self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same") + self.conv0 = UnetConv2d(self.in_channel, filters[0], self.use_bn) + self.conv1 = UnetConv2d(filters[0], filters[1], self.use_bn) + self.conv2 = UnetConv2d(filters[1], filters[2], self.use_bn) + self.conv3 = UnetConv2d(filters[2], filters[3], self.use_bn) + self.conv4 = UnetConv2d(filters[3], filters[4], self.use_bn) + + # Up Sample + self.up_concat1 = UnetUp(filters[1], filters[0], self.use_deconv, 2) + self.up_concat2 = UnetUp(filters[2], filters[1], self.use_deconv, 2) + self.up_concat3 = UnetUp(filters[3], filters[2], self.use_deconv, 2) + self.up_concat4 = UnetUp(filters[4], filters[3], self.use_deconv, 2) + + # Finale Convolution + self.final = nn.Conv2d(filters[0], n_class, 1) + + def construct(self, inputs): + x0 = self.conv0(inputs) # channel = filters[0] + x1 = self.conv1(self.maxpool(x0)) # channel = filters[1] + x2 = self.conv2(self.maxpool(x1)) # channel = filters[2] + x3 = self.conv3(self.maxpool(x2)) # channel = filters[3] + x4 = self.conv4(self.maxpool(x3)) # channel = filters[4] + + up4 = self.up_concat4(x4, x3) + up3 = self.up_concat3(up4, x2) + up2 = self.up_concat2(up3, x1) + up1 = self.up_concat1(up2, x0) + + final = self.final(up1) + + return final diff --git a/model_zoo/official/cv/unet/src/unet_nested/unet_parts.py b/model_zoo/official/cv/unet/src/unet_nested/unet_parts.py new file mode 100644 index 0000000000..8c072a24b8 --- /dev/null +++ b/model_zoo/official/cv/unet/src/unet_nested/unet_parts.py @@ -0,0 +1,81 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# less required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" Parts of the U-Net-PlusPlus model """ + +import mindspore.nn as nn +import mindspore.ops.functional as F +import mindspore.ops.operations as P + + +def conv_bn_relu(in_channel, out_channel, use_bn=True, kernel_size=3, stride=1, pad_mode="same", activation='relu'): + output = [] + output.append(nn.Conv2d(in_channel, out_channel, kernel_size, stride, pad_mode=pad_mode)) + if use_bn: + output.append(nn.BatchNorm2d(out_channel)) + if activation: + output.append(nn.get_activation(activation)) + return nn.SequentialCell(output) + + +class UnetConv2d(nn.Cell): + """ + Convolution block in Unet, usually double conv. + """ + def __init__(self, in_channel, out_channel, use_bn=True, num_layer=2, kernel_size=3, stride=1, padding='same'): + super(UnetConv2d, self).__init__() + self.num_layer = num_layer + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.in_channel = in_channel + self.out_channel = out_channel + + convs = [] + for _ in range(num_layer): + convs.append(conv_bn_relu(in_channel, out_channel, use_bn, kernel_size, stride, padding, "relu")) + in_channel = out_channel + + self.convs = nn.SequentialCell(convs) + + def construct(self, inputs): + x = self.convs(inputs) + return x + + +class UnetUp(nn.Cell): + """ + Upsampling high_feature with factor=2 and concat with low feature + """ + def __init__(self, in_channel, out_channel, use_deconv, n_concat=2): + super(UnetUp, self).__init__() + self.conv = UnetConv2d(in_channel + (n_concat - 2) * out_channel, out_channel, False) + self.concat = P.Concat(axis=1) + self.use_deconv = use_deconv + if use_deconv: + self.up_conv = nn.Conv2dTranspose(in_channel, out_channel, kernel_size=2, stride=2, pad_mode="same") + else: + self.up_conv = nn.Conv2d(in_channel, out_channel, 1) + + def construct(self, high_feature, *low_feature): + if self.use_deconv: + output = self.up_conv(high_feature) + else: + _, _, h, w = F.shape(high_feature) + output = P.ResizeBilinear((h * 2, w * 2))(high_feature) + output = self.up_conv(output) + for feature in low_feature: + output = self.concat((output, feature)) + return self.conv(output) diff --git a/model_zoo/official/cv/unet/src/utils.py b/model_zoo/official/cv/unet/src/utils.py index 8be84a1680..d6bb9d653b 100644 --- a/model_zoo/official/cv/unet/src/utils.py +++ b/model_zoo/official/cv/unet/src/utils.py @@ -15,6 +15,7 @@ import time import numpy as np +from PIL import Image from mindspore.train.callback import Callback from mindspore.common.tensor import Tensor @@ -53,3 +54,6 @@ class StepLossTimeMonitor(Callback): if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: # TEST print("step: %s, loss is %s, fps is %s" % (cur_step_in_epoch, loss, step_fps), flush=True) + +def mask_to_image(mask): + return Image.fromarray((mask * 255).astype(np.uint8)) diff --git a/model_zoo/official/cv/unet/train.py b/model_zoo/official/cv/unet/train.py index dedc94899e..37615f7003 100644 --- a/model_zoo/official/cv/unet/train.py +++ b/model_zoo/official/cv/unet/train.py @@ -26,7 +26,8 @@ from mindspore.train.callback import CheckpointConfig, ModelCheckpoint from mindspore.context import ParallelMode from mindspore.train.serialization import load_checkpoint, load_param_into_net -from src.unet import UNet +from src.unet_medical import UNetMedical +from src.unet_nested import NestedUNet, UNet from src.data_loader import create_dataset from src.loss import CrossEntropyWithLogits from src.utils import StepLossTimeMonitor @@ -53,14 +54,23 @@ def train_net(data_dir, context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=False) - net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) + + if cfg['model'] == 'unet_medical': + net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) + elif cfg['model'] == 'unet_nested': + net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) + elif cfg['model'] == 'unet_simple': + net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) + else: + raise ValueError("Unsupported model: {}".format(cfg['model'])) if cfg['resume']: param_dict = load_checkpoint(cfg['resume_ckpt']) load_param_into_net(net, param_dict) criterion = CrossEntropyWithLogits() - train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute) + train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute, cfg["crop"], + cfg['img_size']) train_data_size = train_dataset.get_dataset_size() print("dataset length is:", train_data_size) ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,