
!12685 Add nested-unet

From: @c_34
Reviewed-by: @wuxuejian
Signed-off-by: @wuxuejian
Tag: v1.2.0-rc1
Committed by mindspore-ci-bot (Gitee), 4 years ago
Commit: b9fac815bc
16 changed files with 371 additions and 40 deletions

  1.  +14  -7   model_zoo/official/cv/unet/eval.py
  2.  +10  -2   model_zoo/official/cv/unet/export.py
  3.  +1   -1   model_zoo/official/cv/unet/mindspore_hub_conf.py
  4.  +5   -6   model_zoo/official/cv/unet/scripts/run_distribute_train.sh
  5.  +8   -5   model_zoo/official/cv/unet/scripts/run_standalone_eval.sh
  6.  +8   -5   model_zoo/official/cv/unet/scripts/run_standalone_train.sh
  7.  +48  -1   model_zoo/official/cv/unet/src/config.py
  8.  +13  -6   model_zoo/official/cv/unet/src/data_loader.py
  9.  +1   -1   model_zoo/official/cv/unet/src/unet_medical/__init__.py
  10. +3   -3   model_zoo/official/cv/unet/src/unet_medical/unet_model.py
  11. +0   -0   model_zoo/official/cv/unet/src/unet_medical/unet_parts.py
  12. +16  -0   model_zoo/official/cv/unet/src/unet_nested/__init__.py
  13. +146 -0   model_zoo/official/cv/unet/src/unet_nested/unet_model.py
  14. +81  -0   model_zoo/official/cv/unet/src/unet_nested/unet_parts.py
  15. +4   -0   model_zoo/official/cv/unet/src/utils.py
  16. +13  -3   model_zoo/official/cv/unet/train.py

model_zoo/official/cv/unet/eval.py  (+14, -7)

@@ -25,7 +25,8 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.nn.loss.loss import _Loss
 
 from src.data_loader import create_dataset
-from src.unet import UNet
+from src.unet_medical import UNetMedical
+from src.unet_nested import NestedUNet, UNet
 from src.config import cfg_unet
 
 from scipy.special import softmax
@@ -34,8 +35,6 @@ device_id = int(os.getenv('DEVICE_ID'))
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
 
 
-
-
 class CrossEntropyWithLogits(_Loss):
     def __init__(self):
         super(CrossEntropyWithLogits, self).__init__()
@@ -64,10 +63,11 @@ class dice_coeff(nn.Metric):
 
     def update(self, *inputs):
         if len(inputs) != 2:
-            raise ValueError('Mean dice coeffcient need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
+            raise ValueError('Mean dice coefficient need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
 
         y_pred = self._convert_data(inputs[0])
         y = self._convert_data(inputs[1])
+
         self._samples_num += y.shape[0]
         y_pred = y_pred.transpose(0, 2, 3, 1)
         y = y.transpose(0, 2, 3, 1)
@@ -90,13 +90,20 @@ def test_net(data_dir,
              ckpt_path,
              cross_valid_ind=1,
              cfg=None):
-
-    net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
+    if cfg['model'] == 'unet_medical':
+        net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
+    elif cfg['model'] == 'unet_nested':
+        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
+    elif cfg['model'] == 'unet_simple':
+        net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
+    else:
+        raise ValueError("Unsupported model: {}".format(cfg['model']))
     param_dict = load_checkpoint(ckpt_path)
     load_param_into_net(net, param_dict)
 
     criterion = CrossEntropyWithLogits()
-    _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False)
+    _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
+                                      do_crop=cfg['crop'], img_size=cfg['img_size'])
     model = Model(net, loss_fn=criterion, metrics={"dice_coeff": dice_coeff()})
 
     print("============== Starting Evaluating ============")


model_zoo/official/cv/unet/export.py  (+10, -2)

@@ -18,7 +18,8 @@ import numpy as np
 
 from mindspore import Tensor, export, load_checkpoint, load_param_into_net, context
 
-from src.unet.unet_model import UNet
+from src.unet_medical.unet_model import UNetMedical
+from src.unet_nested import NestedUNet, UNet
 from src.config import cfg_unet as cfg
 
 parser = argparse.ArgumentParser(description='unet export')
@@ -38,7 +39,14 @@ if args.device_target == "Ascend":
     context.set_context(device_id=args.device_id)
 
 if __name__ == "__main__":
-    net = UNet(n_channels=cfg["num_channels"], n_classes=cfg["num_classes"])
+    if cfg['model'] == 'unet_medical':
+        net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
+    elif cfg['model'] == 'unet_nested':
+        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
+    elif cfg['model'] == 'unet_simple':
+        net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
+    else:
+        raise ValueError("Unsupported model: {}".format(cfg['model']))
     # return a parameter dict for model
     param_dict = load_checkpoint(args.ckpt_file)
     # load the parameter into net


model_zoo/official/cv/unet/mindspore_hub_conf.py  (+1, -1)

@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """hub config."""
-from src.unet import UNet
+from src.unet_medical import UNet
 
 def create_network(name, *args, **kwargs):
     if name == "unet2d":


model_zoo/official/cv/unet/scripts/run_distribute_train.sh  (+5, -6)

@@ -14,15 +14,14 @@
 # limitations under the License.
 # ============================================================================
 
-echo "=============================================================================================================="
-echo "Please run the script as: "
-echo "bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]"
-echo "for example: bash run_distribute_train.sh /absolute/path/to/RANK_TABLE_FILE /absolute/path/to/data"
-echo "=============================================================================================================="
-
 if [ $# != 2 ]
 then
+    echo "=============================================================================================================="
     echo "Usage: bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]"
+    echo "Please run the script as: "
+    echo "bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]"
+    echo "for example: bash run_distribute_train.sh /absolute/path/to/RANK_TABLE_FILE /absolute/path/to/data"
+    echo "=============================================================================================================="
     exit 1
 fi




model_zoo/official/cv/unet/scripts/run_standalone_eval.sh  (+8, -5)

@@ -14,11 +14,14 @@
 # limitations under the License.
 # ============================================================================
 
-echo "=============================================================================================================="
-echo "Please run the script as: "
-echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]"
-echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/"
-echo "=============================================================================================================="
+if [ $# != 2 ]
+then
+    echo "=============================================================================================================="
+    echo "Please run the script as: "
+    echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]"
+    echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/"
+    echo "=============================================================================================================="
+fi
 
 export DEVICE_ID=0
 python eval.py --data_url=$1 --ckpt_path=$2 > eval.log 2>&1 &

model_zoo/official/cv/unet/scripts/run_standalone_train.sh  (+8, -5)

@@ -14,11 +14,14 @@
 # limitations under the License.
 # ============================================================================
 
-echo "=============================================================================================================="
-echo "Please run the script as: "
-echo "bash scripts/run_standalone_train.sh [DATASET]"
-echo "for example: bash run_standalone_train.sh /path/to/data/"
-echo "=============================================================================================================="
+if [ $# != 1 ]
+then
+    echo "=============================================================================================================="
+    echo "Please run the script as: "
+    echo "bash scripts/run_standalone_train.sh [DATASET]"
+    echo "for example: bash run_standalone_train.sh /path/to/data/"
+    echo "=============================================================================================================="
+fi
 
 export DEVICE_ID=0
 python train.py --data_url=$1 > train.log 2>&1 &

model_zoo/official/cv/unet/src/config.py  (+48, -1)

@@ -13,7 +13,10 @@
 # limitations under the License.
 # ============================================================================
 
-cfg_unet = {
+cfg_unet_medical = {
+    'model': 'unet_medical',
+    'crop': [388 / 572, 388 / 572],
+    'img_size': [572, 572],
     'lr': 0.0001,
     'epochs': 400,
     'distribute_epochs': 1600,
@@ -30,3 +33,47 @@ cfg_unet = {
     'resume': False,
     'resume_ckpt': './',
 }
+
+cfg_unet_nested = {
+    'model': 'unet_nested',
+    'crop': None,
+    'img_size': [576, 576],
+    'lr': 0.0001,
+    'epochs': 400,
+    'distribute_epochs': 1600,
+    'batchsize': 16,
+    'cross_valid_ind': 1,
+    'num_classes': 2,
+    'num_channels': 1,
+    'keep_checkpoint_max': 10,
+    'weight_decay': 0.0005,
+    'loss_scale': 1024.0,
+    'FixedLossScaleManager': 1024.0,
+    'resume': False,
+    'resume_ckpt': './',
+}
+
+cfg_unet_simple = {
+    'model': 'unet_simple',
+    'crop': None,
+    'img_size': [576, 576],
+    'lr': 0.0001,
+    'epochs': 400,
+    'distribute_epochs': 1600,
+    'batchsize': 16,
+    'cross_valid_ind': 1,
+    'num_classes': 2,
+    'num_channels': 1,
+    'keep_checkpoint_max': 10,
+    'weight_decay': 0.0005,
+    'loss_scale': 1024.0,
+    'FixedLossScaleManager': 1024.0,
+    'resume': False,
+    'resume_ckpt': './',
+}
+
+cfg_unet = cfg_unet_medical
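Note: the module-level alias on the last line selects which configuration train.py, eval.py and export.py pick up. A minimal sketch of switching to the nested variant (a hypothetical local edit, not part of this commit):

    # src/config.py -- point the alias at the desired config dict
    cfg_unet = cfg_unet_nested    # or cfg_unet_simple; cfg_unet_medical is the default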

model_zoo/official/cv/unet/src/data_loader.py  (+13, -6)

@@ -82,7 +82,8 @@ def train_data_augmentation(img, mask):
     return img, mask
 
 
-def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cross_val_ind=1, run_distribute=False):
+def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cross_val_ind=1, run_distribute=False,
+                   do_crop=None, img_size=None):
 
     images = _load_multipage_tiff(os.path.join(data_dir, 'train-volume.tif'))
     masks = _load_multipage_tiff(os.path.join(data_dir, 'train-labels.tif'))
@@ -121,8 +122,12 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro
     ds_valid_images = ds.NumpySlicesDataset(data=valid_image_data, sampler=None, shuffle=False)
     ds_valid_masks = ds.NumpySlicesDataset(data=valid_mask_data, sampler=None, shuffle=False)
 
-    c_resize_op = c_vision.Resize(size=(388, 388), interpolation=Inter.BILINEAR)
-    c_pad = c_vision.Pad(padding=92)
+    if do_crop:
+        resize_size = [int(img_size[x] * do_crop[x]) for x in range(len(img_size))]
+    else:
+        resize_size = img_size
+    c_resize_op = c_vision.Resize(size=(resize_size[0], resize_size[1]), interpolation=Inter.BILINEAR)
+    c_pad = c_vision.Pad(padding=(img_size[0] - resize_size[0]) // 2)
     c_rescale_image = c_vision.Rescale(1.0/127.5, -1)
     c_rescale_mask = c_vision.Rescale(1.0/255.0, 0)


@@ -136,12 +141,13 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro
     train_ds = train_ds.project(columns=["image", "mask"])
     if augment:
         augment_process = train_data_augmentation
-        c_resize_op = c_vision.Resize(size=(572, 572), interpolation=Inter.BILINEAR)
+        c_resize_op = c_vision.Resize(size=(img_size[0], img_size[1]), interpolation=Inter.BILINEAR)
         train_ds = train_ds.map(input_columns=["image", "mask"], operations=augment_process)
         train_ds = train_ds.map(input_columns="image", operations=c_resize_op)
         train_ds = train_ds.map(input_columns="mask", operations=c_resize_op)
 
-    train_ds = train_ds.map(input_columns="mask", operations=c_center_crop)
+    if do_crop:
+        train_ds = train_ds.map(input_columns="mask", operations=c_center_crop)
     post_process = data_post_process
     train_ds = train_ds.map(input_columns=["image", "mask"], operations=post_process)
     train_ds = train_ds.shuffle(repeat*24)
@@ -151,7 +157,8 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro
     valid_mask_ds = ds_valid_masks.map(input_columns="mask", operations=c_trans_normalize_mask)
     valid_ds = ds.zip((valid_image_ds, valid_mask_ds))
     valid_ds = valid_ds.project(columns=["image", "mask"])
-    valid_ds = valid_ds.map(input_columns="mask", operations=c_center_crop)
+    if do_crop:
+        valid_ds = valid_ds.map(input_columns="mask", operations=c_center_crop)
     post_process = data_post_process
     valid_ds = valid_ds.map(input_columns=["image", "mask"], operations=post_process)
     valid_ds = valid_ds.batch(batch_size=1, drop_remainder=True)
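A worked check of the new resize/pad arithmetic, using the values from the configs above (a sanity check only, not extra code in the commit):

    # unet_medical: img_size = [572, 572], crop = [388 / 572, 388 / 572]
    resize_size = [int(572 * 388 / 572), int(572 * 388 / 572)]   # -> [388, 388]
    padding = (572 - 388) // 2                                   # -> 92, the previously hardcoded Pad(padding=92)
    # unet_nested / unet_simple: img_size = [576, 576], crop = None
    # -> resize stays at 576x576, padding is (576 - 576) // 2 = 0, and the center-crop maps are skipped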


model_zoo/official/cv/unet/src/unet/__init__.py → model_zoo/official/cv/unet/src/unet_medical/__init__.py  (+1, -1)

@@ -13,4 +13,4 @@
 # limitations under the License.
 # ============================================================================
 
-from .unet_model import UNet
+from .unet_model import UNetMedical

model_zoo/official/cv/unet/src/unet/unet_model.py → model_zoo/official/cv/unet/src/unet_medical/unet_model.py  (+3, -3)

@@ -13,12 +13,12 @@
 # limitations under the License.
 # ============================================================================
 
-from src.unet.unet_parts import DoubleConv, Down, Up1, Up2, Up3, Up4, OutConv
+from src.unet_medical.unet_parts import DoubleConv, Down, Up1, Up2, Up3, Up4, OutConv
 import mindspore.nn as nn
 
-class UNet(nn.Cell):
+class UNetMedical(nn.Cell):
     def __init__(self, n_channels, n_classes):
-        super(UNet, self).__init__()
+        super(UNetMedical, self).__init__()
         self.n_channels = n_channels
         self.n_classes = n_classes
         self.inc = DoubleConv(n_channels, 64)

model_zoo/official/cv/unet/src/unet/unet_parts.py → model_zoo/official/cv/unet/src/unet_medical/unet_parts.py  (+0, -0, renamed only)


model_zoo/official/cv/unet/src/unet_nested/__init__.py  (+16, -0)

@@ -0,0 +1,16 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from .unet_model import NestedUNet, UNet

model_zoo/official/cv/unet/src/unet_nested/unet_model.py  (+146, -0)

@@ -0,0 +1,146 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# Model of UnetPlusPlus

import mindspore.nn as nn
from .unet_parts import UnetConv2d, UnetUp


class NestedUNet(nn.Cell):
    """
    Nested unet
    """
    def __init__(self, in_channel, n_class=2, feature_scale=2, use_deconv=True, use_bn=True, use_ds=True):
        super(NestedUNet, self).__init__()
        self.in_channel = in_channel
        self.n_class = n_class
        self.feature_scale = feature_scale
        self.use_deconv = use_deconv
        self.use_bn = use_bn
        self.use_ds = use_ds

        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]

        # Down Sample
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same")
        self.conv00 = UnetConv2d(self.in_channel, filters[0], self.use_bn)
        self.conv10 = UnetConv2d(filters[0], filters[1], self.use_bn)
        self.conv20 = UnetConv2d(filters[1], filters[2], self.use_bn)
        self.conv30 = UnetConv2d(filters[2], filters[3], self.use_bn)
        self.conv40 = UnetConv2d(filters[3], filters[4], self.use_bn)

        # Up Sample
        self.up_concat01 = UnetUp(filters[1], filters[0], self.use_deconv, 2)
        self.up_concat11 = UnetUp(filters[2], filters[1], self.use_deconv, 2)
        self.up_concat21 = UnetUp(filters[3], filters[2], self.use_deconv, 2)
        self.up_concat31 = UnetUp(filters[4], filters[3], self.use_deconv, 2)

        self.up_concat02 = UnetUp(filters[1], filters[0], self.use_deconv, 3)
        self.up_concat12 = UnetUp(filters[2], filters[1], self.use_deconv, 3)
        self.up_concat22 = UnetUp(filters[3], filters[2], self.use_deconv, 3)

        self.up_concat03 = UnetUp(filters[1], filters[0], self.use_deconv, 4)
        self.up_concat13 = UnetUp(filters[2], filters[1], self.use_deconv, 4)

        self.up_concat04 = UnetUp(filters[1], filters[0], self.use_deconv, 5)

        # Final Convolution
        self.final1 = nn.Conv2d(filters[0], n_class, 1)
        self.final2 = nn.Conv2d(filters[0], n_class, 1)
        self.final3 = nn.Conv2d(filters[0], n_class, 1)
        self.final4 = nn.Conv2d(filters[0], n_class, 1)

    def construct(self, inputs):
        x00 = self.conv00(inputs)                    # channel = filters[0]
        x10 = self.conv10(self.maxpool(x00))         # channel = filters[1]
        x20 = self.conv20(self.maxpool(x10))         # channel = filters[2]
        x30 = self.conv30(self.maxpool(x20))         # channel = filters[3]
        x40 = self.conv40(self.maxpool(x30))         # channel = filters[4]

        x01 = self.up_concat01(x10, x00)             # channel = filters[0]
        x11 = self.up_concat11(x20, x10)             # channel = filters[1]
        x21 = self.up_concat21(x30, x20)             # channel = filters[2]
        x31 = self.up_concat31(x40, x30)             # channel = filters[3]

        x02 = self.up_concat02(x11, x00, x01)        # channel = filters[0]
        x12 = self.up_concat12(x21, x10, x11)        # channel = filters[1]
        x22 = self.up_concat22(x31, x20, x21)        # channel = filters[2]

        x03 = self.up_concat03(x12, x00, x01, x02)   # channel = filters[0]
        x13 = self.up_concat13(x22, x10, x11, x12)   # channel = filters[1]

        x04 = self.up_concat04(x13, x00, x01, x02, x03)  # channel = filters[0]

        # each deep-supervision head uses its own 1x1 convolution
        final1 = self.final1(x01)
        final2 = self.final2(x02)
        final3 = self.final3(x03)
        final4 = self.final4(x04)

        final = (final1 + final2 + final3 + final4) / 4.0

        if self.use_ds:
            return final
        return final4


class UNet(nn.Cell):
    """
    Simple UNet with skip connection
    """
    def __init__(self, in_channel, n_class=2, feature_scale=2, use_deconv=True, use_bn=True):
        super(UNet, self).__init__()
        self.in_channel = in_channel
        self.n_class = n_class
        self.feature_scale = feature_scale
        self.use_deconv = use_deconv
        self.use_bn = use_bn

        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]

        # Down Sample
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same")
        self.conv0 = UnetConv2d(self.in_channel, filters[0], self.use_bn)
        self.conv1 = UnetConv2d(filters[0], filters[1], self.use_bn)
        self.conv2 = UnetConv2d(filters[1], filters[2], self.use_bn)
        self.conv3 = UnetConv2d(filters[2], filters[3], self.use_bn)
        self.conv4 = UnetConv2d(filters[3], filters[4], self.use_bn)

        # Up Sample
        self.up_concat1 = UnetUp(filters[1], filters[0], self.use_deconv, 2)
        self.up_concat2 = UnetUp(filters[2], filters[1], self.use_deconv, 2)
        self.up_concat3 = UnetUp(filters[3], filters[2], self.use_deconv, 2)
        self.up_concat4 = UnetUp(filters[4], filters[3], self.use_deconv, 2)

        # Final Convolution
        self.final = nn.Conv2d(filters[0], n_class, 1)

    def construct(self, inputs):
        x0 = self.conv0(inputs)                  # channel = filters[0]
        x1 = self.conv1(self.maxpool(x0))        # channel = filters[1]
        x2 = self.conv2(self.maxpool(x1))        # channel = filters[2]
        x3 = self.conv3(self.maxpool(x2))        # channel = filters[3]
        x4 = self.conv4(self.maxpool(x3))        # channel = filters[4]

        up4 = self.up_concat4(x4, x3)
        up3 = self.up_concat3(up4, x2)
        up2 = self.up_concat2(up3, x1)
        up1 = self.up_concat1(up2, x0)

        final = self.final(up1)

        return final
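A minimal smoke test of the two new networks (input shape is an illustrative assumption matching cfg_unet_nested, not part of the commit; run from the unet model directory so that src is importable):

    import numpy as np
    import mindspore as ms
    from src.unet_nested import NestedUNet, UNet

    net = NestedUNet(in_channel=1, n_class=2)                      # 1 channel, 2 classes as in cfg_unet_nested
    x = ms.Tensor(np.random.rand(1, 1, 576, 576).astype(np.float32))
    print(net(x).shape)                                            # expected (1, 2, 576, 576) per-pixel logits

    simple = UNet(in_channel=1, n_class=2)                         # the 'unet_simple' variant
    print(simple(x).shape)                                         # expected (1, 2, 576, 576)

With use_ds=True (the default) NestedUNet averages the four 1x1-conv heads; with use_ds=False it returns only the deepest head's output.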

model_zoo/official/cv/unet/src/unet_nested/unet_parts.py  (+81, -0)

@@ -0,0 +1,81 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

""" Parts of the U-Net-PlusPlus model """

import mindspore.nn as nn
import mindspore.ops.functional as F
import mindspore.ops.operations as P


def conv_bn_relu(in_channel, out_channel, use_bn=True, kernel_size=3, stride=1, pad_mode="same", activation='relu'):
    output = []
    output.append(nn.Conv2d(in_channel, out_channel, kernel_size, stride, pad_mode=pad_mode))
    if use_bn:
        output.append(nn.BatchNorm2d(out_channel))
    if activation:
        output.append(nn.get_activation(activation))
    return nn.SequentialCell(output)


class UnetConv2d(nn.Cell):
    """
    Convolution block in Unet, usually double conv.
    """
    def __init__(self, in_channel, out_channel, use_bn=True, num_layer=2, kernel_size=3, stride=1, padding='same'):
        super(UnetConv2d, self).__init__()
        self.num_layer = num_layer
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.in_channel = in_channel
        self.out_channel = out_channel

        convs = []
        for _ in range(num_layer):
            convs.append(conv_bn_relu(in_channel, out_channel, use_bn, kernel_size, stride, padding, "relu"))
            in_channel = out_channel

        self.convs = nn.SequentialCell(convs)

    def construct(self, inputs):
        x = self.convs(inputs)
        return x


class UnetUp(nn.Cell):
    """
    Upsampling high_feature with factor=2 and concat with low feature
    """
    def __init__(self, in_channel, out_channel, use_deconv, n_concat=2):
        super(UnetUp, self).__init__()
        self.conv = UnetConv2d(in_channel + (n_concat - 2) * out_channel, out_channel, False)
        self.concat = P.Concat(axis=1)
        self.use_deconv = use_deconv
        if use_deconv:
            self.up_conv = nn.Conv2dTranspose(in_channel, out_channel, kernel_size=2, stride=2, pad_mode="same")
        else:
            self.up_conv = nn.Conv2d(in_channel, out_channel, 1)

    def construct(self, high_feature, *low_feature):
        if self.use_deconv:
            output = self.up_conv(high_feature)
        else:
            _, _, h, w = F.shape(high_feature)
            output = P.ResizeBilinear((h * 2, w * 2))(high_feature)
            output = self.up_conv(output)
        for feature in low_feature:
            output = self.concat((output, feature))
        return self.conv(output)
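The input width of self.conv follows from the concatenation: the upsampled high_feature contributes out_channel channels and each of the n_concat - 1 low features contributes out_channel as well, so the conv sees n_concat * out_channel channels; the expression in_channel + (n_concat - 2) * out_channel equals that because every UnetUp in this model is built with in_channel = 2 * out_channel. A worked check with the default feature_scale=2 (filters = [32, 64, 128, 256, 512]):

    # up_concat02 = UnetUp(filters[1], filters[0], use_deconv, n_concat=3)
    # up_conv: 64 -> 32 channels, then concat with x00 (32) and x01 (32) -> 96 channels
    # conv input = in_channel + (n_concat - 2) * out_channel = 64 + 1 * 32 = 96  (= 3 * 32)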

model_zoo/official/cv/unet/src/utils.py  (+4, -0)

@@ -15,6 +15,7 @@
 import time
 import numpy as np
+from PIL import Image
 from mindspore.train.callback import Callback
 from mindspore.common.tensor import Tensor
 
@@ -53,3 +54,6 @@ class StepLossTimeMonitor(Callback):
        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
            # TEST
            print("step: %s, loss is %s, fps is %s" % (cur_step_in_epoch, loss, step_fps), flush=True)
+
+def mask_to_image(mask):
+    return Image.fromarray((mask * 255).astype(np.uint8))
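A quick illustrative use of the new helper (the array shape is an assumption for the example, not part of the diff):

    import numpy as np
    from src.utils import mask_to_image

    mask = (np.random.rand(388, 388) > 0.5).astype(np.float32)   # binary mask with values in [0, 1]
    mask_to_image(mask).save("mask.png")                         # saved as an 8-bit grayscale image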

model_zoo/official/cv/unet/train.py  (+13, -3)

@@ -26,7 +26,8 @@ from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
 from mindspore.context import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 
-from src.unet import UNet
+from src.unet_medical import UNetMedical
+from src.unet_nested import NestedUNet, UNet
 from src.data_loader import create_dataset
 from src.loss import CrossEntropyWithLogits
 from src.utils import StepLossTimeMonitor
@@ -53,14 +54,23 @@ def train_net(data_dir,
         context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                           device_num=group_size,
                                           gradients_mean=False)
-    net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
+
+    if cfg['model'] == 'unet_medical':
+        net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
+    elif cfg['model'] == 'unet_nested':
+        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
+    elif cfg['model'] == 'unet_simple':
+        net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
+    else:
+        raise ValueError("Unsupported model: {}".format(cfg['model']))
 
     if cfg['resume']:
         param_dict = load_checkpoint(cfg['resume_ckpt'])
         load_param_into_net(net, param_dict)
 
     criterion = CrossEntropyWithLogits()
-    train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute)
+    train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute, cfg["crop"],
+                                      cfg['img_size'])
     train_data_size = train_dataset.get_dataset_size()
     print("dataset length is:", train_data_size)
     ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,

