@@ -24,7 +24,7 @@ from mindspore import context, Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.nn.loss.loss import _Loss
-from src.data_loader import create_dataset
+from src.data_loader import create_dataset, create_cell_nuclei_dataset
 from src.unet_medical import UNetMedical
 from src.unet_nested import NestedUNet, UNet
 from src.config import cfg_unet
@@ -59,6 +59,7 @@ class dice_coeff(nn.Metric):
         self.clear()

     def clear(self):
         self._dice_coeff_sum = 0
+        self._iou_sum = 0
         self._samples_num = 0

     def update(self, *inputs):
@@ -77,13 +78,15 @@ class dice_coeff(nn.Metric):
         union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
         single_dice_coeff = 2*float(inter)/float(union+1e-6)
-        print("single dice coeff is:", single_dice_coeff)
+        single_iou = single_dice_coeff / (2 - single_dice_coeff)
+        print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou))
         self._dice_coeff_sum += single_dice_coeff
+        self._iou_sum += single_iou

     def eval(self):
         if self._samples_num == 0:
             raise RuntimeError('Total samples num must not be 0.')
-        return self._dice_coeff_sum / float(self._samples_num)
+        return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num))


 def test_net(data_dir,
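Note: the IoU is derived from the running Dice value rather than recomputed from the raw intersection; with I = a·b, Dice = 2I/(a·a + b·b) and IoU = I/(a·a + b·b − I), which gives IoU = Dice/(2 − Dice). A minimal numpy check of that identity:

    import numpy as np
    a = np.array([1, 1, 0, 1], dtype=float)
    b = np.array([1, 0, 0, 1], dtype=float)
    inter = np.dot(a, b)                 # 2.0
    denom = np.dot(a, a) + np.dot(b, b)  # 5.0
    dice = 2 * inter / denom             # 0.8
    assert np.isclose(inter / (denom - inter), dice / (2 - dice))  # both 2/3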
@@ -93,7 +96,8 @@ def test_net(data_dir,
     if cfg['model'] == 'unet_medical':
         net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
     elif cfg['model'] == 'unet_nested':
-        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
+        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
+                         use_bn=cfg['use_bn'], use_ds=False)
     elif cfg['model'] == 'unet_simple':
         net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
     else:
@@ -102,13 +106,17 @@ def test_net(data_dir,
     load_param_into_net(net, param_dict)

     criterion = CrossEntropyWithLogits()
-    _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
-                                      do_crop=cfg['crop'], img_size=cfg['img_size'])
+    if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
+        valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False, split=0.8)
+    else:
+        _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
+                                          do_crop=cfg['crop'], img_size=cfg['img_size'])
     model = Model(net, loss_fn=criterion, metrics={"dice_coeff": dice_coeff()})

     print("============== Starting Evaluating ============")
-    dice_score = model.eval(valid_dataset, dataset_sink_mode=False)
-    print("============== Cross valid dice coeff is:", dice_score)
+    eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"]
+    print("============== Cross valid dice coeff is:", eval_score[0])
+    print("============== Cross valid IOU is:", eval_score[1])


 def get_args():
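Note: Model.eval returns a dict keyed by the metric names passed to Model(), and dice_coeff.eval() now returns a (dice, iou) tuple instead of a scalar — hence the ["dice_coeff"] lookup and the two indexed prints. A usage sketch (assuming net, criterion and valid_dataset from the surrounding code):

    model = Model(net, loss_fn=criterion, metrics={"dice_coeff": dice_coeff()})
    dice, iou = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"]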
@@ -42,7 +42,8 @@ if __name__ == "__main__":
     if cfg['model'] == 'unet_medical':
         net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
     elif cfg['model'] == 'unet_nested':
-        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
+        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
+                         use_bn=cfg['use_bn'], use_ds=False)
     elif cfg['model'] == 'unet_simple':
         net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
     else:
@@ -50,6 +50,34 @@ cfg_unet_nested = {
     'weight_decay': 0.0005,
     'loss_scale': 1024.0,
     'FixedLossScaleManager': 1024.0,
+    'use_bn': True,
+    'use_ds': True,
+    'use_deconv': True,
+
+    'resume': False,
+    'resume_ckpt': './',
+}
+
+cfg_unet_nested_cell = {
+    'model': 'unet_nested',
+    'dataset': 'Cell_nuclei',
+    'crop': None,
+    'img_size': [96, 96],
+    'lr': 3e-4,
+    'epochs': 200,
+    'distribute_epochs': 1600,
+    'batchsize': 16,
+    'cross_valid_ind': 1,
+    'num_classes': 2,
+    'num_channels': 3,
+
+    'keep_checkpoint_max': 10,
+    'weight_decay': 0.0005,
+    'loss_scale': 1024.0,
+    'FixedLossScaleManager': 1024.0,
+    'use_bn': True,
+    'use_ds': True,
+    'use_deconv': True,

     'resume': False,
     'resume_ckpt': './',
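Note: the new preset only takes effect when it is the dict exported as cfg_unet, which train.py and eval.py import; the selection line at the bottom of src/config.py (the exact alias is an assumption, it is not shown in this diff) would be repointed like so:

    # hypothetical: switch the active config to the cell-nuclei preset
    cfg_unet = cfg_unet_nested_cell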
@@ -15,6 +15,7 @@
 import os
 from collections import deque

+import cv2
 import numpy as np
 from PIL import Image, ImageSequence
 import mindspore.dataset as ds
@@ -23,7 +24,6 @@ from mindspore.dataset.vision.utils import Inter
 from mindspore.communication.management import get_rank, get_group_size

-
 def _load_multipage_tiff(path):
     """Load tiff images containing many images in the channel dimension"""
     return np.array([np.array(p) for p in ImageSequence.Iterator(Image.open(path))])
@@ -164,3 +164,100 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro
     valid_ds = valid_ds.batch(batch_size=1, drop_remainder=True)

     return train_ds, valid_ds
+
+
+class CellNucleiDataset:
+    """
+    Cell nuclei dataset preprocess class.
+    """
+    def __init__(self, data_dir, repeat, is_train=False, split=0.8):
+        self.data_dir = data_dir
+        self.img_ids = sorted(next(os.walk(self.data_dir))[1])
+        self.train_ids = self.img_ids[:int(len(self.img_ids) * split)] * repeat
+        np.random.shuffle(self.train_ids)
+        self.val_ids = self.img_ids[int(len(self.img_ids) * split):]
+        self.is_train = is_train
+        self._preprocess_dataset()
+
+    def _preprocess_dataset(self):
+        for img_id in self.img_ids:
+            path = os.path.join(self.data_dir, img_id)
+            if (not os.path.exists(os.path.join(path, "image.png"))) or \
+                    (not os.path.exists(os.path.join(path, "mask.png"))):
+                img = cv2.imread(os.path.join(path, "images", img_id + ".png"))
+                if len(img.shape) == 2:
+                    img = np.expand_dims(img, axis=-1)
+                    img = np.concatenate([img, img, img], axis=-1)
+                mask = []
+                for mask_file in next(os.walk(os.path.join(path, "masks")))[2]:
+                    mask_ = cv2.imread(os.path.join(path, "masks", mask_file), cv2.IMREAD_GRAYSCALE)
+                    mask.append(mask_)
+                mask = np.max(mask, axis=0)
+                cv2.imwrite(os.path.join(path, "image.png"), img)
+                cv2.imwrite(os.path.join(path, "mask.png"), mask)
+
+    def _read_img_mask(self, img_id):
+        path = os.path.join(self.data_dir, img_id)
+        img = cv2.imread(os.path.join(path, "image.png"))
+        mask = cv2.imread(os.path.join(path, "mask.png"), cv2.IMREAD_GRAYSCALE)
+        return img, mask
+
+    def __getitem__(self, index):
+        if self.is_train:
+            return self._read_img_mask(self.train_ids[index])
+        return self._read_img_mask(self.val_ids[index])
+
+    @property
+    def column_names(self):
+        column_names = ['image', 'mask']
+        return column_names
+
+    def __len__(self):
+        if self.is_train:
+            return len(self.train_ids)
+        return len(self.val_ids)
+
+
+def preprocess_img_mask(img, mask, img_size, augment=False):
+    """
+    Preprocess for cell nuclei dataset.
+    Random crop and flip images and masks when augment is True.
+    """
+    if augment:
+        img_size_w = int(np.random.randint(img_size[0], img_size[0] * 1.5, 1))
+        img_size_h = int(np.random.randint(img_size[1], img_size[1] * 1.5, 1))
+        img = cv2.resize(img, (img_size_w, img_size_h))
+        mask = cv2.resize(mask, (img_size_w, img_size_h))
+        dw = int(np.random.randint(0, img_size_w - img_size[0] + 1, 1))
+        dh = int(np.random.randint(0, img_size_h - img_size[1] + 1, 1))
+        img = img[dh:dh+img_size[1], dw:dw+img_size[0], :]
+        mask = mask[dh:dh+img_size[1], dw:dw+img_size[0]]
+        if np.random.random() > 0.5:
+            flip_code = int(np.random.randint(-1, 2, 1))
+            img = cv2.flip(img, flip_code)
+            mask = cv2.flip(mask, flip_code)
+    else:
+        img = cv2.resize(img, img_size)
+        mask = cv2.resize(mask, img_size)
+    img = (img.astype(np.float32) - 127.5) / 127.5
+    img = img.transpose(2, 0, 1)
+    mask = mask.astype(np.float32) / 255
+    mask = (mask > 0.5).astype(np.int)
+    mask = (np.arange(2) == mask[..., None]).astype(int)
+    mask = mask.transpose(2, 0, 1).astype(np.float32)
+    return img, mask
+
+
+def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train=False, augment=False,
+                               split=0.8, rank=0, group_size=1, python_multiprocessing=True,
+                               num_parallel_workers=8):
+    """
+    Get generator dataset for cell nuclei dataset.
+    """
+    cell_dataset = CellNucleiDataset(data_dir, repeat, is_train, split)
+    sampler = ds.DistributedSampler(group_size, rank, shuffle=is_train)
+    dataset = ds.GeneratorDataset(cell_dataset, cell_dataset.column_names, sampler=sampler)
+    compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, tuple(img_size), augment and is_train))
+    dataset = dataset.map(operations=compose_map_func, input_columns=cell_dataset.column_names,
+                          output_columns=cell_dataset.column_names, column_order=cell_dataset.column_names,
+                          python_multiprocessing=python_multiprocessing,
+                          num_parallel_workers=num_parallel_workers)
+    dataset = dataset.batch(batch_size, drop_remainder=is_train)
+    dataset = dataset.repeat(1)
+    return dataset
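Note: a usage sketch of the new loader (the data path is hypothetical; the directory layout is the one implied by _preprocess_dataset: one folder per image id containing images/<id>.png and masks/*.png, with merged image.png/mask.png cached beside them on first use):

    train_ds = create_cell_nuclei_dataset("/data/cell_nuclei", img_size=[96, 96],
                                          repeat=10, batch_size=16,
                                          is_train=True, augment=True)
    for img, mask in train_ds.create_tuple_iterator():
        print(img.shape, mask.shape)  # (16, 3, 96, 96), (16, 2, 96, 96)
        break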
@@ -36,3 +36,15 @@ class CrossEntropyWithLogits(_Loss):
         loss = self.reduce_mean(
             self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, 2)), self.reshape_fn(label, (-1, 2))))
         return self.get_loss(loss)
+
+
+class MultiCrossEntropyWithLogits(nn.Cell):
+    def __init__(self):
+        super(MultiCrossEntropyWithLogits, self).__init__()
+        self.loss = CrossEntropyWithLogits()
+        self.squeeze = F.Squeeze()
+
+    def construct(self, logits, label):
+        total_loss = 0
+        for i in range(len(logits)):
+            total_loss += self.loss(self.squeeze(logits[i:i+1]), label)
+        return total_loss
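Note: with deep supervision the network output is the four prediction heads stacked on a new axis 0 (see the NestedUNet change below), so logits here has shape (4, N, C, H, W) and the same label is applied to every head. A numpy sketch of what the slice-and-squeeze does per head:

    import numpy as np
    logits = np.zeros((4, 16, 2, 96, 96), dtype=np.float32)  # 4 heads, batch 16
    head = logits[1:2]           # (1, 16, 2, 96, 96): slicing keeps a unit axis
    head = head.squeeze(axis=0)  # (16, 2, 96, 96): what F.Squeeze() removes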
@@ -16,6 +16,7 @@
 # Model of UnetPlusPlus
 import mindspore.nn as nn
+import mindspore.ops as P

 from .unet_parts import UnetConv2d, UnetUp
@@ -63,6 +64,7 @@ class NestedUNet(nn.Cell):
         self.final2 = nn.Conv2d(filters[0], n_class, 1)
         self.final3 = nn.Conv2d(filters[0], n_class, 1)
         self.final4 = nn.Conv2d(filters[0], n_class, 1)
+        self.stack = P.Stack(axis=0)

     def construct(self, inputs):
         x00 = self.conv00(inputs)    # channel = filters[0]
@@ -86,13 +88,12 @@ class NestedUNet(nn.Cell):
         x04 = self.up_concat04(x13, x00, x01, x02, x03)    # channel = filters[0]

         final1 = self.final1(x01)
-        final2 = self.final1(x02)
-        final3 = self.final1(x03)
-        final4 = self.final1(x04)
-
-        final = (final1 + final2 + final3 + final4) / 4.0
+        final2 = self.final2(x02)
+        final3 = self.final3(x03)
+        final4 = self.final4(x04)

         if self.use_ds:
+            final = self.stack((final1, final2, final3, final4))
             return final
         return final4
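Note: besides wiring each head to its own 1x1 convolution (final2/3/4 previously all reused final1), the deep-supervision path now stacks the four heads instead of averaging them, leaving the weighting to the loss. Return shapes, roughly (the NestedUNet keyword defaults are assumed):

    net_train = NestedUNet(in_channel=3, n_class=2, use_deconv=True, use_bn=True, use_ds=True)
    net_eval = NestedUNet(in_channel=3, n_class=2, use_deconv=True, use_bn=True, use_ds=False)
    # net_train(x) -> (4, N, C, H, W): stacked heads, fed to MultiCrossEntropyWithLogits
    # net_eval(x)  -> (N, C, H, W):    final4 only, fed to the dice/IoU metric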
@@ -51,9 +51,23 @@ class StepLossTimeMonitor(Callback):
         if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
             raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
                 cb_params.cur_epoch_num, cur_step_in_epoch))
+        self.losses.append(loss)

         if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
             # TEST
             print("step: %s, loss is %s, fps is %s" % (cur_step_in_epoch, loss, step_fps), flush=True)

+    def epoch_begin(self, run_context):
+        self.epoch_start = time.time()
+        self.losses = []
+
+    def epoch_end(self, run_context):
+        cb_params = run_context.original_args()
+        epoch_cost = time.time() - self.epoch_start
+        step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
+        step_fps = self.batch_size * 1.0 * step_in_epoch / epoch_cost
+        print("epoch: {:3d}, avg loss:{:.4f}, total cost: {:.3f} s, per step fps:{:5.3f}".format(
+            cb_params.cur_epoch_num, np.mean(self.losses), epoch_cost, step_fps), flush=True)
+

 def mask_to_image(mask):
     return Image.fromarray((mask * 255).astype(np.uint8))
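Note: despite the "per step fps" label, the epoch summary reports images per second averaged over the whole epoch, which stays meaningful under dataset sink mode where per-step prints are disabled (per_print_times=0). A worked example of the arithmetic:

    # 600 steps of batch 16 finished in 120 s -> 16 * 600 / 120 = 80 images/s
    batch_size, steps_in_epoch, epoch_cost = 16, 600, 120.0
    step_fps = batch_size * 1.0 * steps_in_epoch / epoch_cost  # 80.0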
@@ -21,15 +21,15 @@ import ast
 import mindspore
 import mindspore.nn as nn
 from mindspore import Model, context
-from mindspore.communication.management import init, get_group_size
+from mindspore.communication.management import init, get_group_size, get_rank
 from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
 from mindspore.context import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net

 from src.unet_medical import UNetMedical
 from src.unet_nested import NestedUNet, UNet
-from src.data_loader import create_dataset
-from src.loss import CrossEntropyWithLogits
+from src.data_loader import create_dataset, create_cell_nuclei_dataset
+from src.loss import CrossEntropyWithLogits, MultiCrossEntropyWithLogits
 from src.utils import StepLossTimeMonitor
 from src.config import cfg_unet
@@ -46,10 +46,12 @@ def train_net(data_dir,
               run_distribute=False,
               cfg=None):
+    rank = 0
+    group_size = 1
     if run_distribute:
         init()
         group_size = get_group_size()
+        rank = get_rank()
         parallel_mode = ParallelMode.DATA_PARALLEL
         context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                           device_num=group_size,
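Note: rank and group_size default to single-device values and are only overwritten under run_distribute; they are threaded into create_cell_nuclei_dataset, where they parameterize the sampler so each device reads a disjoint shard, roughly:

    import mindspore.dataset as ds
    # each of group_size devices reads the shard selected by its rank
    sampler = ds.DistributedSampler(group_size, rank, shuffle=True)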
@@ -58,7 +60,8 @@ def train_net(data_dir,
     if cfg['model'] == 'unet_medical':
         net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
     elif cfg['model'] == 'unet_nested':
-        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
+        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
+                         use_bn=cfg['use_bn'], use_ds=cfg['use_ds'])
     elif cfg['model'] == 'unet_simple':
         net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
     else:
@@ -68,14 +71,28 @@ def train_net(data_dir,
         param_dict = load_checkpoint(cfg['resume_ckpt'])
         load_param_into_net(net, param_dict)

-    criterion = CrossEntropyWithLogits()
-    train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute, cfg["crop"],
-                                      cfg['img_size'])
+    if 'use_ds' in cfg and cfg['use_ds']:
+        criterion = MultiCrossEntropyWithLogits()
+    else:
+        criterion = CrossEntropyWithLogits()
+
+    if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
+        repeat = 10
+        dataset_sink_mode = True
+        per_print_times = 0
+        train_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], repeat, batch_size,
+                                                   is_train=True, augment=True, split=0.8, rank=rank,
+                                                   group_size=group_size)
+    else:
+        repeat = epochs
+        dataset_sink_mode = False
+        per_print_times = 1
+        train_dataset, _ = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind, run_distribute,
+                                          cfg["crop"], cfg['img_size'])
     train_data_size = train_dataset.get_dataset_size()
     print("dataset length is:", train_data_size)
     ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
                                    keep_checkpoint_max=cfg['keep_checkpoint_max'])
-    ckpoint_cb = ModelCheckpoint(prefix='ckpt_unet_medical_adam',
+    ckpoint_cb = ModelCheckpoint(prefix='ckpt_{}_adam'.format(cfg['model']),
                                  directory='./ckpt_{}/'.format(device_id),
                                  config=ckpt_config)
@@ -87,13 +104,11 @@ def train_net(data_dir,
     model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3")

     print("============== Starting Training ==============")
-    model.train(1, train_dataset, callbacks=[StepLossTimeMonitor(batch_size=batch_size), ckpoint_cb],
-                dataset_sink_mode=False)
+    callbacks = [StepLossTimeMonitor(batch_size=batch_size, per_print_times=per_print_times), ckpoint_cb]
+    model.train(int(epochs / repeat), train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)
     print("============== End Training ==============")


 def get_args():
     parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
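Note: the two dataset branches keep the total number of passes over the data constant. The tiff pipeline bakes all epochs into the dataset (repeat = epochs), so model.train runs int(epochs / repeat) = 1 outer epoch exactly as before; the cell-nuclei pipeline materializes the data 10x per dataset, so 200 configured epochs become 20 outer iterations of the sunk graph:

    epochs, repeat = 200, 10             # cfg_unet_nested_cell['epochs'], cell-nuclei repeat
    outer_epochs = int(epochs / repeat)  # 20 calls into model.train's epoch loop
    assert outer_epochs * repeat == epochs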