From: @c_34 Reviewed-by: @wuxuejian Signed-off-by: @wuxuejiantags/v1.2.0-rc1
| @@ -25,7 +25,8 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from mindspore.nn.loss.loss import _Loss | from mindspore.nn.loss.loss import _Loss | ||||
| from src.data_loader import create_dataset | from src.data_loader import create_dataset | ||||
| from src.unet import UNet | |||||
| from src.unet_medical import UNetMedical | |||||
| from src.unet_nested import NestedUNet, UNet | |||||
| from src.config import cfg_unet | from src.config import cfg_unet | ||||
| from scipy.special import softmax | from scipy.special import softmax | ||||
| @@ -34,8 +35,6 @@ device_id = int(os.getenv('DEVICE_ID')) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) | ||||
| class CrossEntropyWithLogits(_Loss): | class CrossEntropyWithLogits(_Loss): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(CrossEntropyWithLogits, self).__init__() | super(CrossEntropyWithLogits, self).__init__() | ||||
| @@ -64,10 +63,11 @@ class dice_coeff(nn.Metric): | |||||
| def update(self, *inputs): | def update(self, *inputs): | ||||
| if len(inputs) != 2: | if len(inputs) != 2: | ||||
| raise ValueError('Mean dice coeffcient need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) | |||||
| raise ValueError('Mean dice coefficient need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) | |||||
| y_pred = self._convert_data(inputs[0]) | y_pred = self._convert_data(inputs[0]) | ||||
| y = self._convert_data(inputs[1]) | y = self._convert_data(inputs[1]) | ||||
| self._samples_num += y.shape[0] | self._samples_num += y.shape[0] | ||||
| y_pred = y_pred.transpose(0, 2, 3, 1) | y_pred = y_pred.transpose(0, 2, 3, 1) | ||||
| y = y.transpose(0, 2, 3, 1) | y = y.transpose(0, 2, 3, 1) | ||||
| @@ -90,13 +90,20 @@ def test_net(data_dir, | |||||
| ckpt_path, | ckpt_path, | ||||
| cross_valid_ind=1, | cross_valid_ind=1, | ||||
| cfg=None): | cfg=None): | ||||
| net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) | |||||
| if cfg['model'] == 'unet_medical': | |||||
| net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) | |||||
| elif cfg['model'] == 'unet_nested': | |||||
| net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) | |||||
| elif cfg['model'] == 'unet_simple': | |||||
| net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) | |||||
| else: | |||||
| raise ValueError("Unsupported model: {}".format(cfg['model'])) | |||||
| param_dict = load_checkpoint(ckpt_path) | param_dict = load_checkpoint(ckpt_path) | ||||
| load_param_into_net(net, param_dict) | load_param_into_net(net, param_dict) | ||||
| criterion = CrossEntropyWithLogits() | criterion = CrossEntropyWithLogits() | ||||
| _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False) | |||||
| _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, | |||||
| do_crop=cfg['crop'], img_size=cfg['img_size']) | |||||
| model = Model(net, loss_fn=criterion, metrics={"dice_coeff": dice_coeff()}) | model = Model(net, loss_fn=criterion, metrics={"dice_coeff": dice_coeff()}) | ||||
| print("============== Starting Evaluating ============") | print("============== Starting Evaluating ============") | ||||
| @@ -18,7 +18,8 @@ import numpy as np | |||||
| from mindspore import Tensor, export, load_checkpoint, load_param_into_net, context | from mindspore import Tensor, export, load_checkpoint, load_param_into_net, context | ||||
| from src.unet.unet_model import UNet | |||||
| from src.unet_medical.unet_model import UNetMedical | |||||
| from src.unet_nested import NestedUNet, UNet | |||||
| from src.config import cfg_unet as cfg | from src.config import cfg_unet as cfg | ||||
| parser = argparse.ArgumentParser(description='unet export') | parser = argparse.ArgumentParser(description='unet export') | ||||
| @@ -38,7 +39,14 @@ if args.device_target == "Ascend": | |||||
| context.set_context(device_id=args.device_id) | context.set_context(device_id=args.device_id) | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| net = UNet(n_channels=cfg["num_channels"], n_classes=cfg["num_classes"]) | |||||
| if cfg['model'] == 'unet_medical': | |||||
| net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) | |||||
| elif cfg['model'] == 'unet_nested': | |||||
| net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) | |||||
| elif cfg['model'] == 'unet_simple': | |||||
| net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) | |||||
| else: | |||||
| raise ValueError("Unsupported model: {}".format(cfg['model'])) | |||||
| # return a parameter dict for model | # return a parameter dict for model | ||||
| param_dict = load_checkpoint(args.ckpt_file) | param_dict = load_checkpoint(args.ckpt_file) | ||||
| # load the parameter into net | # load the parameter into net | ||||
| @@ -13,7 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """hub config.""" | """hub config.""" | ||||
| from src.unet import UNet | |||||
| from src.unet_medical import UNet | |||||
| def create_network(name, *args, **kwargs): | def create_network(name, *args, **kwargs): | ||||
| if name == "unet2d": | if name == "unet2d": | ||||
| @@ -14,15 +14,14 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| echo "==============================================================================================================" | |||||
| echo "Please run the script as: " | |||||
| echo "bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]" | |||||
| echo "for example: bash run_distribute_train.sh /absolute/path/to/RANK_TABLE_FILE /absolute/path/to/data" | |||||
| echo "==============================================================================================================" | |||||
| if [ $# != 2 ] | if [ $# != 2 ] | ||||
| then | then | ||||
| echo "==============================================================================================================" | |||||
| echo "Usage: bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]" | echo "Usage: bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]" | ||||
| echo "Please run the script as: " | |||||
| echo "bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET]" | |||||
| echo "for example: bash run_distribute_train.sh /absolute/path/to/RANK_TABLE_FILE /absolute/path/to/data" | |||||
| echo "==============================================================================================================" | |||||
| exit 1 | exit 1 | ||||
| fi | fi | ||||
| @@ -14,11 +14,14 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| echo "==============================================================================================================" | |||||
| echo "Please run the script as: " | |||||
| echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]" | |||||
| echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/" | |||||
| echo "==============================================================================================================" | |||||
| if [ $# != 2 ] | |||||
| then | |||||
| echo "==============================================================================================================" | |||||
| echo "Please run the script as: " | |||||
| echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]" | |||||
| echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/" | |||||
| echo "==============================================================================================================" | |||||
| fi | |||||
| export DEVICE_ID=0 | export DEVICE_ID=0 | ||||
| python eval.py --data_url=$1 --ckpt_path=$2 > eval.log 2>&1 & | python eval.py --data_url=$1 --ckpt_path=$2 > eval.log 2>&1 & | ||||
| @@ -14,11 +14,14 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| echo "==============================================================================================================" | |||||
| echo "Please run the script as: " | |||||
| echo "bash scripts/run_standalone_train.sh [DATASET]" | |||||
| echo "for example: bash run_standalone_train.sh /path/to/data/" | |||||
| echo "==============================================================================================================" | |||||
| if [ $# != 1 ] | |||||
| then | |||||
| echo "==============================================================================================================" | |||||
| echo "Please run the script as: " | |||||
| echo "bash scripts/run_standalone_train.sh [DATASET]" | |||||
| echo "for example: bash run_standalone_train.sh /path/to/data/" | |||||
| echo "==============================================================================================================" | |||||
| fi | |||||
| export DEVICE_ID=0 | export DEVICE_ID=0 | ||||
| python train.py --data_url=$1 > train.log 2>&1 & | python train.py --data_url=$1 > train.log 2>&1 & | ||||
| @@ -13,7 +13,10 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| cfg_unet = { | |||||
| cfg_unet_medical = { | |||||
| 'model': 'unet_medical', | |||||
| 'crop': [388 / 572, 388 / 572], | |||||
| 'img_size': [572, 572], | |||||
| 'lr': 0.0001, | 'lr': 0.0001, | ||||
| 'epochs': 400, | 'epochs': 400, | ||||
| 'distribute_epochs': 1600, | 'distribute_epochs': 1600, | ||||
| @@ -30,3 +33,47 @@ cfg_unet = { | |||||
| 'resume': False, | 'resume': False, | ||||
| 'resume_ckpt': './', | 'resume_ckpt': './', | ||||
| } | } | ||||
| cfg_unet_nested = { | |||||
| 'model': 'unet_nested', | |||||
| 'crop': None, | |||||
| 'img_size': [576, 576], | |||||
| 'lr': 0.0001, | |||||
| 'epochs': 400, | |||||
| 'distribute_epochs': 1600, | |||||
| 'batchsize': 16, | |||||
| 'cross_valid_ind': 1, | |||||
| 'num_classes': 2, | |||||
| 'num_channels': 1, | |||||
| 'keep_checkpoint_max': 10, | |||||
| 'weight_decay': 0.0005, | |||||
| 'loss_scale': 1024.0, | |||||
| 'FixedLossScaleManager': 1024.0, | |||||
| 'resume': False, | |||||
| 'resume_ckpt': './', | |||||
| } | |||||
| cfg_unet_simple = { | |||||
| 'model': 'unet_simple', | |||||
| 'crop': None, | |||||
| 'img_size': [576, 576], | |||||
| 'lr': 0.0001, | |||||
| 'epochs': 400, | |||||
| 'distribute_epochs': 1600, | |||||
| 'batchsize': 16, | |||||
| 'cross_valid_ind': 1, | |||||
| 'num_classes': 2, | |||||
| 'num_channels': 1, | |||||
| 'keep_checkpoint_max': 10, | |||||
| 'weight_decay': 0.0005, | |||||
| 'loss_scale': 1024.0, | |||||
| 'FixedLossScaleManager': 1024.0, | |||||
| 'resume': False, | |||||
| 'resume_ckpt': './', | |||||
| } | |||||
| cfg_unet = cfg_unet_medical | |||||
| @@ -82,7 +82,8 @@ def train_data_augmentation(img, mask): | |||||
| return img, mask | return img, mask | ||||
| def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cross_val_ind=1, run_distribute=False): | |||||
| def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cross_val_ind=1, run_distribute=False, | |||||
| do_crop=None, img_size=None): | |||||
| images = _load_multipage_tiff(os.path.join(data_dir, 'train-volume.tif')) | images = _load_multipage_tiff(os.path.join(data_dir, 'train-volume.tif')) | ||||
| masks = _load_multipage_tiff(os.path.join(data_dir, 'train-labels.tif')) | masks = _load_multipage_tiff(os.path.join(data_dir, 'train-labels.tif')) | ||||
| @@ -121,8 +122,12 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro | |||||
| ds_valid_images = ds.NumpySlicesDataset(data=valid_image_data, sampler=None, shuffle=False) | ds_valid_images = ds.NumpySlicesDataset(data=valid_image_data, sampler=None, shuffle=False) | ||||
| ds_valid_masks = ds.NumpySlicesDataset(data=valid_mask_data, sampler=None, shuffle=False) | ds_valid_masks = ds.NumpySlicesDataset(data=valid_mask_data, sampler=None, shuffle=False) | ||||
| c_resize_op = c_vision.Resize(size=(388, 388), interpolation=Inter.BILINEAR) | |||||
| c_pad = c_vision.Pad(padding=92) | |||||
| if do_crop: | |||||
| resize_size = [int(img_size[x] * do_crop[x]) for x in range(len(img_size))] | |||||
| else: | |||||
| resize_size = img_size | |||||
| c_resize_op = c_vision.Resize(size=(resize_size[0], resize_size[1]), interpolation=Inter.BILINEAR) | |||||
| c_pad = c_vision.Pad(padding=(img_size[0] - resize_size[0]) // 2) | |||||
| c_rescale_image = c_vision.Rescale(1.0/127.5, -1) | c_rescale_image = c_vision.Rescale(1.0/127.5, -1) | ||||
| c_rescale_mask = c_vision.Rescale(1.0/255.0, 0) | c_rescale_mask = c_vision.Rescale(1.0/255.0, 0) | ||||
| @@ -136,12 +141,13 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro | |||||
| train_ds = train_ds.project(columns=["image", "mask"]) | train_ds = train_ds.project(columns=["image", "mask"]) | ||||
| if augment: | if augment: | ||||
| augment_process = train_data_augmentation | augment_process = train_data_augmentation | ||||
| c_resize_op = c_vision.Resize(size=(572, 572), interpolation=Inter.BILINEAR) | |||||
| c_resize_op = c_vision.Resize(size=(img_size[0], img_size[1]), interpolation=Inter.BILINEAR) | |||||
| train_ds = train_ds.map(input_columns=["image", "mask"], operations=augment_process) | train_ds = train_ds.map(input_columns=["image", "mask"], operations=augment_process) | ||||
| train_ds = train_ds.map(input_columns="image", operations=c_resize_op) | train_ds = train_ds.map(input_columns="image", operations=c_resize_op) | ||||
| train_ds = train_ds.map(input_columns="mask", operations=c_resize_op) | train_ds = train_ds.map(input_columns="mask", operations=c_resize_op) | ||||
| train_ds = train_ds.map(input_columns="mask", operations=c_center_crop) | |||||
| if do_crop: | |||||
| train_ds = train_ds.map(input_columns="mask", operations=c_center_crop) | |||||
| post_process = data_post_process | post_process = data_post_process | ||||
| train_ds = train_ds.map(input_columns=["image", "mask"], operations=post_process) | train_ds = train_ds.map(input_columns=["image", "mask"], operations=post_process) | ||||
| train_ds = train_ds.shuffle(repeat*24) | train_ds = train_ds.shuffle(repeat*24) | ||||
| @@ -151,7 +157,8 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro | |||||
| valid_mask_ds = ds_valid_masks.map(input_columns="mask", operations=c_trans_normalize_mask) | valid_mask_ds = ds_valid_masks.map(input_columns="mask", operations=c_trans_normalize_mask) | ||||
| valid_ds = ds.zip((valid_image_ds, valid_mask_ds)) | valid_ds = ds.zip((valid_image_ds, valid_mask_ds)) | ||||
| valid_ds = valid_ds.project(columns=["image", "mask"]) | valid_ds = valid_ds.project(columns=["image", "mask"]) | ||||
| valid_ds = valid_ds.map(input_columns="mask", operations=c_center_crop) | |||||
| if do_crop: | |||||
| valid_ds = valid_ds.map(input_columns="mask", operations=c_center_crop) | |||||
| post_process = data_post_process | post_process = data_post_process | ||||
| valid_ds = valid_ds.map(input_columns=["image", "mask"], operations=post_process) | valid_ds = valid_ds.map(input_columns=["image", "mask"], operations=post_process) | ||||
| valid_ds = valid_ds.batch(batch_size=1, drop_remainder=True) | valid_ds = valid_ds.batch(batch_size=1, drop_remainder=True) | ||||
| @@ -13,4 +13,4 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| from .unet_model import UNet | |||||
| from .unet_model import UNetMedical | |||||
| @@ -13,12 +13,12 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| from src.unet.unet_parts import DoubleConv, Down, Up1, Up2, Up3, Up4, OutConv | |||||
| from src.unet_medical.unet_parts import DoubleConv, Down, Up1, Up2, Up3, Up4, OutConv | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| class UNet(nn.Cell): | |||||
| class UNetMedical(nn.Cell): | |||||
| def __init__(self, n_channels, n_classes): | def __init__(self, n_channels, n_classes): | ||||
| super(UNet, self).__init__() | |||||
| super(UNetMedical, self).__init__() | |||||
| self.n_channels = n_channels | self.n_channels = n_channels | ||||
| self.n_classes = n_classes | self.n_classes = n_classes | ||||
| self.inc = DoubleConv(n_channels, 64) | self.inc = DoubleConv(n_channels, 64) | ||||
| @@ -0,0 +1,16 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # less required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| from .unet_model import NestedUNet, UNet | |||||
| @@ -0,0 +1,146 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # less required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| # Model of UnetPlusPlus | |||||
| import mindspore.nn as nn | |||||
| from .unet_parts import UnetConv2d, UnetUp | |||||
| class NestedUNet(nn.Cell): | |||||
| """ | |||||
| Nested unet | |||||
| """ | |||||
| def __init__(self, in_channel, n_class=2, feature_scale=2, use_deconv=True, use_bn=True, use_ds=True): | |||||
| super(NestedUNet, self).__init__() | |||||
| self.in_channel = in_channel | |||||
| self.n_class = n_class | |||||
| self.feature_scale = feature_scale | |||||
| self.use_deconv = use_deconv | |||||
| self.use_bn = use_bn | |||||
| self.use_ds = use_ds | |||||
| filters = [64, 128, 256, 512, 1024] | |||||
| filters = [int(x / self.feature_scale) for x in filters] | |||||
| # Down Sample | |||||
| self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same") | |||||
| self.conv00 = UnetConv2d(self.in_channel, filters[0], self.use_bn) | |||||
| self.conv10 = UnetConv2d(filters[0], filters[1], self.use_bn) | |||||
| self.conv20 = UnetConv2d(filters[1], filters[2], self.use_bn) | |||||
| self.conv30 = UnetConv2d(filters[2], filters[3], self.use_bn) | |||||
| self.conv40 = UnetConv2d(filters[3], filters[4], self.use_bn) | |||||
| # Up Sample | |||||
| self.up_concat01 = UnetUp(filters[1], filters[0], self.use_deconv, 2) | |||||
| self.up_concat11 = UnetUp(filters[2], filters[1], self.use_deconv, 2) | |||||
| self.up_concat21 = UnetUp(filters[3], filters[2], self.use_deconv, 2) | |||||
| self.up_concat31 = UnetUp(filters[4], filters[3], self.use_deconv, 2) | |||||
| self.up_concat02 = UnetUp(filters[1], filters[0], self.use_deconv, 3) | |||||
| self.up_concat12 = UnetUp(filters[2], filters[1], self.use_deconv, 3) | |||||
| self.up_concat22 = UnetUp(filters[3], filters[2], self.use_deconv, 3) | |||||
| self.up_concat03 = UnetUp(filters[1], filters[0], self.use_deconv, 4) | |||||
| self.up_concat13 = UnetUp(filters[2], filters[1], self.use_deconv, 4) | |||||
| self.up_concat04 = UnetUp(filters[1], filters[0], self.use_deconv, 5) | |||||
| # Finale Convolution | |||||
| self.final1 = nn.Conv2d(filters[0], n_class, 1) | |||||
| self.final2 = nn.Conv2d(filters[0], n_class, 1) | |||||
| self.final3 = nn.Conv2d(filters[0], n_class, 1) | |||||
| self.final4 = nn.Conv2d(filters[0], n_class, 1) | |||||
| def construct(self, inputs): | |||||
| x00 = self.conv00(inputs) # channel = filters[0] | |||||
| x10 = self.conv10(self.maxpool(x00)) # channel = filters[1] | |||||
| x20 = self.conv20(self.maxpool(x10)) # channel = filters[2] | |||||
| x30 = self.conv30(self.maxpool(x20)) # channel = filters[3] | |||||
| x40 = self.conv40(self.maxpool(x30)) # channel = filters[4] | |||||
| x01 = self.up_concat01(x10, x00) # channel = filters[0] | |||||
| x11 = self.up_concat11(x20, x10) # channel = filters[1] | |||||
| x21 = self.up_concat21(x30, x20) # channel = filters[2] | |||||
| x31 = self.up_concat31(x40, x30) # channel = filters[3] | |||||
| x02 = self.up_concat02(x11, x00, x01) # channel = filters[0] | |||||
| x12 = self.up_concat12(x21, x10, x11) # channel = filters[1] | |||||
| x22 = self.up_concat22(x31, x20, x21) # channel = filters[2] | |||||
| x03 = self.up_concat03(x12, x00, x01, x02) # channel = filters[0] | |||||
| x13 = self.up_concat13(x22, x10, x11, x12) # channel = filters[1] | |||||
| x04 = self.up_concat04(x13, x00, x01, x02, x03) # channel = filters[0] | |||||
| final1 = self.final1(x01) | |||||
| final2 = self.final1(x02) | |||||
| final3 = self.final1(x03) | |||||
| final4 = self.final1(x04) | |||||
| final = (final1 + final2 + final3 + final4) / 4.0 | |||||
| if self.use_ds: | |||||
| return final | |||||
| return final4 | |||||
| class UNet(nn.Cell): | |||||
| """ | |||||
| Simple UNet with skip connection | |||||
| """ | |||||
| def __init__(self, in_channel, n_class=2, feature_scale=2, use_deconv=True, use_bn=True): | |||||
| super(UNet, self).__init__() | |||||
| self.in_channel = in_channel | |||||
| self.n_class = n_class | |||||
| self.feature_scale = feature_scale | |||||
| self.use_deconv = use_deconv | |||||
| self.use_bn = use_bn | |||||
| filters = [64, 128, 256, 512, 1024] | |||||
| filters = [int(x / self.feature_scale) for x in filters] | |||||
| # Down Sample | |||||
| self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same") | |||||
| self.conv0 = UnetConv2d(self.in_channel, filters[0], self.use_bn) | |||||
| self.conv1 = UnetConv2d(filters[0], filters[1], self.use_bn) | |||||
| self.conv2 = UnetConv2d(filters[1], filters[2], self.use_bn) | |||||
| self.conv3 = UnetConv2d(filters[2], filters[3], self.use_bn) | |||||
| self.conv4 = UnetConv2d(filters[3], filters[4], self.use_bn) | |||||
| # Up Sample | |||||
| self.up_concat1 = UnetUp(filters[1], filters[0], self.use_deconv, 2) | |||||
| self.up_concat2 = UnetUp(filters[2], filters[1], self.use_deconv, 2) | |||||
| self.up_concat3 = UnetUp(filters[3], filters[2], self.use_deconv, 2) | |||||
| self.up_concat4 = UnetUp(filters[4], filters[3], self.use_deconv, 2) | |||||
| # Finale Convolution | |||||
| self.final = nn.Conv2d(filters[0], n_class, 1) | |||||
| def construct(self, inputs): | |||||
| x0 = self.conv0(inputs) # channel = filters[0] | |||||
| x1 = self.conv1(self.maxpool(x0)) # channel = filters[1] | |||||
| x2 = self.conv2(self.maxpool(x1)) # channel = filters[2] | |||||
| x3 = self.conv3(self.maxpool(x2)) # channel = filters[3] | |||||
| x4 = self.conv4(self.maxpool(x3)) # channel = filters[4] | |||||
| up4 = self.up_concat4(x4, x3) | |||||
| up3 = self.up_concat3(up4, x2) | |||||
| up2 = self.up_concat2(up3, x1) | |||||
| up1 = self.up_concat1(up2, x0) | |||||
| final = self.final(up1) | |||||
| return final | |||||
| @@ -0,0 +1,81 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # less required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ Parts of the U-Net-PlusPlus model """ | |||||
| import mindspore.nn as nn | |||||
| import mindspore.ops.functional as F | |||||
| import mindspore.ops.operations as P | |||||
| def conv_bn_relu(in_channel, out_channel, use_bn=True, kernel_size=3, stride=1, pad_mode="same", activation='relu'): | |||||
| output = [] | |||||
| output.append(nn.Conv2d(in_channel, out_channel, kernel_size, stride, pad_mode=pad_mode)) | |||||
| if use_bn: | |||||
| output.append(nn.BatchNorm2d(out_channel)) | |||||
| if activation: | |||||
| output.append(nn.get_activation(activation)) | |||||
| return nn.SequentialCell(output) | |||||
| class UnetConv2d(nn.Cell): | |||||
| """ | |||||
| Convolution block in Unet, usually double conv. | |||||
| """ | |||||
| def __init__(self, in_channel, out_channel, use_bn=True, num_layer=2, kernel_size=3, stride=1, padding='same'): | |||||
| super(UnetConv2d, self).__init__() | |||||
| self.num_layer = num_layer | |||||
| self.kernel_size = kernel_size | |||||
| self.stride = stride | |||||
| self.padding = padding | |||||
| self.in_channel = in_channel | |||||
| self.out_channel = out_channel | |||||
| convs = [] | |||||
| for _ in range(num_layer): | |||||
| convs.append(conv_bn_relu(in_channel, out_channel, use_bn, kernel_size, stride, padding, "relu")) | |||||
| in_channel = out_channel | |||||
| self.convs = nn.SequentialCell(convs) | |||||
| def construct(self, inputs): | |||||
| x = self.convs(inputs) | |||||
| return x | |||||
| class UnetUp(nn.Cell): | |||||
| """ | |||||
| Upsampling high_feature with factor=2 and concat with low feature | |||||
| """ | |||||
| def __init__(self, in_channel, out_channel, use_deconv, n_concat=2): | |||||
| super(UnetUp, self).__init__() | |||||
| self.conv = UnetConv2d(in_channel + (n_concat - 2) * out_channel, out_channel, False) | |||||
| self.concat = P.Concat(axis=1) | |||||
| self.use_deconv = use_deconv | |||||
| if use_deconv: | |||||
| self.up_conv = nn.Conv2dTranspose(in_channel, out_channel, kernel_size=2, stride=2, pad_mode="same") | |||||
| else: | |||||
| self.up_conv = nn.Conv2d(in_channel, out_channel, 1) | |||||
| def construct(self, high_feature, *low_feature): | |||||
| if self.use_deconv: | |||||
| output = self.up_conv(high_feature) | |||||
| else: | |||||
| _, _, h, w = F.shape(high_feature) | |||||
| output = P.ResizeBilinear((h * 2, w * 2))(high_feature) | |||||
| output = self.up_conv(output) | |||||
| for feature in low_feature: | |||||
| output = self.concat((output, feature)) | |||||
| return self.conv(output) | |||||
| @@ -15,6 +15,7 @@ | |||||
| import time | import time | ||||
| import numpy as np | import numpy as np | ||||
| from PIL import Image | |||||
| from mindspore.train.callback import Callback | from mindspore.train.callback import Callback | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| @@ -53,3 +54,6 @@ class StepLossTimeMonitor(Callback): | |||||
| if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: | if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: | ||||
| # TEST | # TEST | ||||
| print("step: %s, loss is %s, fps is %s" % (cur_step_in_epoch, loss, step_fps), flush=True) | print("step: %s, loss is %s, fps is %s" % (cur_step_in_epoch, loss, step_fps), flush=True) | ||||
| def mask_to_image(mask): | |||||
| return Image.fromarray((mask * 255).astype(np.uint8)) | |||||
| @@ -26,7 +26,8 @@ from mindspore.train.callback import CheckpointConfig, ModelCheckpoint | |||||
| from mindspore.context import ParallelMode | from mindspore.context import ParallelMode | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| from src.unet import UNet | |||||
| from src.unet_medical import UNetMedical | |||||
| from src.unet_nested import NestedUNet, UNet | |||||
| from src.data_loader import create_dataset | from src.data_loader import create_dataset | ||||
| from src.loss import CrossEntropyWithLogits | from src.loss import CrossEntropyWithLogits | ||||
| from src.utils import StepLossTimeMonitor | from src.utils import StepLossTimeMonitor | ||||
| @@ -53,14 +54,23 @@ def train_net(data_dir, | |||||
| context.set_auto_parallel_context(parallel_mode=parallel_mode, | context.set_auto_parallel_context(parallel_mode=parallel_mode, | ||||
| device_num=group_size, | device_num=group_size, | ||||
| gradients_mean=False) | gradients_mean=False) | ||||
| net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) | |||||
| if cfg['model'] == 'unet_medical': | |||||
| net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) | |||||
| elif cfg['model'] == 'unet_nested': | |||||
| net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) | |||||
| elif cfg['model'] == 'unet_simple': | |||||
| net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) | |||||
| else: | |||||
| raise ValueError("Unsupported model: {}".format(cfg['model'])) | |||||
| if cfg['resume']: | if cfg['resume']: | ||||
| param_dict = load_checkpoint(cfg['resume_ckpt']) | param_dict = load_checkpoint(cfg['resume_ckpt']) | ||||
| load_param_into_net(net, param_dict) | load_param_into_net(net, param_dict) | ||||
| criterion = CrossEntropyWithLogits() | criterion = CrossEntropyWithLogits() | ||||
| train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute) | |||||
| train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute, cfg["crop"], | |||||
| cfg['img_size']) | |||||
| train_data_size = train_dataset.get_dataset_size() | train_data_size = train_dataset.get_dataset_size() | ||||
| print("dataset length is:", train_data_size) | print("dataset length is:", train_data_size) | ||||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size, | ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size, | ||||