
train.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
import logging
import ast

import mindspore
import mindspore.nn as nn
from mindspore import Model, context
from mindspore.communication.management import init, get_group_size, get_rank
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.unet_medical import UNetMedical
from src.unet_nested import NestedUNet, UNet
from src.data_loader import create_dataset, create_cell_nuclei_dataset
from src.loss import CrossEntropyWithLogits, MultiCrossEntropyWithLogits
from src.utils import StepLossTimeMonitor, UnetEval, TempLoss, apply_eval, filter_checkpoint_parameter_by_list, dice_coeff
from src.config import cfg_unet
from src.eval_callback import EvalCallBack

device_id = int(os.getenv('DEVICE_ID', '0'))  # default to device 0 when DEVICE_ID is unset
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
mindspore.set_seed(1)


def train_net(args_opt,
              cross_valid_ind=1,
              epochs=400,
              batch_size=16,
              lr=0.0001,
              cfg=None):
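    """Build the network, loss, datasets, optimizer and callbacks, then run training."""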
    rank = 0
    group_size = 1
    data_dir = args_opt.data_url
    run_distribute = args_opt.run_distribute
    if run_distribute:
        # Initialize the communication group and switch to data-parallel training.
        init()
        group_size = get_group_size()
        rank = get_rank()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                          device_num=group_size,
                                          gradients_mean=False)
    # Build the backbone named in the config.
    need_slice = False
    if cfg['model'] == 'unet_medical':
        net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
    elif cfg['model'] == 'unet_nested':
        net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
                         use_bn=cfg['use_bn'], use_ds=cfg['use_ds'])
        # With deep supervision the network returns several outputs; evaluation keeps only the last.
        need_slice = cfg['use_ds']
    elif cfg['model'] == 'unet_simple':
        net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
    else:
        raise ValueError("Unsupported model: {}".format(cfg['model']))
    if cfg['resume']:
        # Resume from a checkpoint; for transfer training, drop the listed weights first.
        param_dict = load_checkpoint(cfg['resume_ckpt'])
        if cfg['transfer_training']:
            filter_checkpoint_parameter_by_list(param_dict, cfg['filter_weight'])
        load_param_into_net(net, param_dict)

    # Deep supervision needs a loss applied to every intermediate output.
    if 'use_ds' in cfg and cfg['use_ds']:
        criterion = MultiCrossEntropyWithLogits()
    else:
        criterion = CrossEntropyWithLogits()
    if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
        repeat = cfg['repeat']
        dataset_sink_mode = True
        per_print_times = 0
        train_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], repeat, batch_size,
                                                   is_train=True, augment=True, split=0.8, rank=rank,
                                                   group_size=group_size)
        valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False,
                                                   eval_resize=cfg["eval_resize"], split=0.8,
                                                   python_multiprocessing=False)
    else:
        repeat = cfg['repeat']
        dataset_sink_mode = False
        per_print_times = 1
        train_dataset, valid_dataset = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind,
                                                      run_distribute, cfg["crop"], cfg['img_size'])
    train_data_size = train_dataset.get_dataset_size()
    print("dataset length is:", train_data_size)
    ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
                                   keep_checkpoint_max=cfg['keep_checkpoint_max'])
    ckpoint_cb = ModelCheckpoint(prefix='ckpt_{}_adam'.format(cfg['model']),
                                 directory='./ckpt_{}/'.format(device_id),
                                 config=ckpt_config)
    optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr, weight_decay=cfg['weight_decay'],
                        loss_scale=cfg['loss_scale'])
    loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(cfg['FixedLossScaleManager'], False)
    model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3")
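    # Note: amp_level="O3" above runs the network in float16 on Ascend; the fixed
    # loss scale (shared by the optimizer and FixedLossScaleManager) guards
    # against float16 underflow in the gradients.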
    print("============== Starting Training ==============")
    callbacks = [StepLossTimeMonitor(batch_size=batch_size, per_print_times=per_print_times), ckpoint_cb]
    if args_opt.run_eval:
        eval_model = Model(UnetEval(net, need_slice=need_slice), loss_fn=TempLoss(),
                           metrics={"dice_coeff": dice_coeff(cfg_unet, False)})
        eval_param_dict = {"model": eval_model, "dataset": valid_dataset, "metrics_name": args_opt.eval_metrics}
        # save_best_ckpt follows the command-line flag; 'besk_ckpt_name' is the
        # keyword name used by src/eval_callback.py.
        eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval,
                               eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=args_opt.save_best_ckpt,
                               ckpt_directory='./ckpt_{}/'.format(device_id), besk_ckpt_name="best.ckpt",
                               metrics_name=args_opt.eval_metrics)
        callbacks.append(eval_cb)
    # The dataset pipeline repeats `repeat` times per epoch, so divide to keep the intended number of passes.
    model.train(int(epochs / repeat), train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)
    print("============== End Training ==============")

def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/',
                        help='data directory')
    parser.add_argument('-t', '--run_distribute', type=ast.literal_eval,
                        default=False, help='Run distribute, default: false.')
    parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
                        help="Run evaluation when training, default is False.")
    parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
                        help="Save best checkpoint when run_eval is True, default is True.")
    parser.add_argument("--eval_start_epoch", type=int, default=0,
                        help="Evaluation start epoch when run_eval is True, default is 0.")
    parser.add_argument("--eval_interval", type=int, default=1,
                        help="Evaluation interval when run_eval is True, default is 1.")
    parser.add_argument("--eval_metrics", type=str, default="dice_coeff", choices=("dice_coeff", "iou"),
                        help="Evaluation metrics when run_eval is True, support [dice_coeff, iou], "
                             "default is dice_coeff.")
    return parser.parse_args()

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    args = get_args()
    print("Training setting:", args)
    epoch_size = cfg_unet['epochs'] if not args.run_distribute else cfg_unet['distribute_epochs']
    train_net(args_opt=args,
              cross_valid_ind=cfg_unet['cross_valid_ind'],
              epochs=epoch_size,
              batch_size=cfg_unet['batchsize'],
              lr=cfg_unet['lr'],
              cfg=cfg_unet)
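
A typical single-device launch, for illustration (the flags are those defined in get_args above; DEVICE_ID selects the Ascend chip and falls back to 0 when unset):

    DEVICE_ID=0 python train.py --data_url=data/ --run_eval=True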