|
|
|
@@ -19,7 +19,6 @@ |
|
|
|
- [How to use](#how-to-use)
|
|
|
|
- [Inference](#inference)
|
|
|
|
- [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
|
|
|
|
- [Transfer Learning](#transfer-learning)
|
|
|
|
- [Description of Random Situation](#description-of-random-situation)
|
|
|
|
- [ModelZoo Homepage](#modelzoo-homepage)
|
|
|
|
|
|
|
|
@@ -130,6 +129,8 @@ Parameters for both training and evaluation can be set in config.py |
|
|
|
'weight_decay': 0.0005, # weight decay value
|
|
|
|
'loss_scale': 1024.0, # loss scale
|
|
|
|
'FixedLossScaleManager': 1024.0, # fix loss scale
|
|
|
|
'resume': False, # whether training with pretrain model
|
|
|
|
'resume_ckpt': './', # pretrain model path
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
|
|
@@ -260,8 +261,42 @@ If you need to use the trained model to perform inference on multiple hardware p |
|
|
|
print("============== Cross valid dice coeff is:", dice_score)
|
|
|
|
```
|
|
|
|
|
|
|
|
### Transfer Learning
|
|
|
|
To be added.
|
|
|
|
### Continue Training on the Pretrained Model
|
|
|
|
|
|
|
|
- running on Ascend
|
|
|
|
|
|
|
|
```
|
|
|
|
# Define model
|
|
|
|
net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
|
|
|
|
# Continue training if set 'resume' to be True
|
|
|
|
if cfg['resume']:
|
|
|
|
param_dict = load_checkpoint(cfg['resume_ckpt'])
|
|
|
|
load_param_into_net(net, param_dict)
|
|
|
|
|
|
|
|
# Load dataset
|
|
|
|
train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute)
|
|
|
|
train_data_size = train_dataset.get_dataset_size()
|
|
|
|
|
|
|
|
optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr, weight_decay=cfg['weight_decay'],
|
|
|
|
loss_scale=cfg['loss_scale'])
|
|
|
|
criterion = CrossEntropyWithLogits()
|
|
|
|
loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(cfg['FixedLossScaleManager'], False)
|
|
|
|
|
|
|
|
model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3")
|
|
|
|
|
|
|
|
|
|
|
|
# Set callbacks
|
|
|
|
ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
|
|
|
|
keep_checkpoint_max=cfg['keep_checkpoint_max'])
|
|
|
|
ckpoint_cb = ModelCheckpoint(prefix='ckpt_unet_medical_adam',
|
|
|
|
directory='./ckpt_{}/'.format(device_id),
|
|
|
|
config=ckpt_config)
|
|
|
|
|
|
|
|
print("============== Starting Training ==============")
|
|
|
|
model.train(1, train_dataset, callbacks=[StepLossTimeMonitor(batch_size=batch_size), ckpoint_cb],
|
|
|
|
dataset_sink_mode=False)
|
|
|
|
print("============== End Training ==============")
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
|
|
# [Description of Random Situation](#contents)
|
|
|
|
|