train.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train tinydarknet example on imagenet########################
python train.py
"""
import os
import argparse

from mindspore import Tensor
from mindspore import context
from mindspore.communication.management import init, get_rank
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed

from src.config import imagenet_cfg
from src.dataset import create_dataset_imagenet
from src.tinydarknet import TinyDarkNet
from src.CrossEntropySmooth import CrossEntropySmooth

set_seed(1)


def lr_steps_imagenet(_cfg, steps_per_epoch):
    """lr step for imagenet"""
    from src.lr_scheduler.warmup_step_lr import warmup_step_lr
    from src.lr_scheduler.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
    if _cfg.lr_scheduler == 'exponential':
        _lr = warmup_step_lr(_cfg.lr_init,
                             _cfg.lr_epochs,
                             steps_per_epoch,
                             _cfg.warmup_epochs,
                             _cfg.epoch_size,
                             gamma=_cfg.lr_gamma,
                             )
    elif _cfg.lr_scheduler == 'cosine_annealing':
        _lr = warmup_cosine_annealing_lr(_cfg.lr_init,
                                         steps_per_epoch,
                                         _cfg.warmup_epochs,
                                         _cfg.epoch_size,
                                         _cfg.T_max,
                                         _cfg.eta_min)
    else:
        raise NotImplementedError(_cfg.lr_scheduler)
    return _lr
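# lr_steps_imagenet precomputes the learning rate for every training step:
# 'exponential' presumably combines a warmup phase with step decay by lr_gamma at
# the epochs listed in lr_epochs, while 'cosine_annealing' follows warmup with a
# cosine decay from lr_init towards eta_min over T_max; the exact curves are
# defined in src/lr_scheduler, which is not shown on this page.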
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Classification')
    parser.add_argument('--dataset_name', type=str, default='imagenet', choices=['imagenet', 'cifar10'],
                        help='dataset name.')
    parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: 0)')
    args_opt = parser.parse_args()

    if args_opt.dataset_name == "imagenet":
        cfg = imagenet_cfg
    else:
        raise ValueError("Unsupported dataset.")

    # set context
    device_target = cfg.device_target
    context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
    device_num = int(os.environ.get("DEVICE_NUM", 1))

    rank = 0
    if device_target == "Ascend":
        context.set_context(device_id=args_opt.device_id)

        if device_num > 1:
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True)
            init()
            rank = get_rank()
    else:
        raise ValueError("Unsupported platform.")
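    # DEVICE_NUM > 1 switches to data-parallel training: init() brings up the
    # communication backend and get_rank() assigns this process its rank, which is
    # later used to give each device its own checkpoint directory.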
    if args_opt.dataset_name == "imagenet":
        dataset = create_dataset_imagenet(cfg.data_path, 1)
    else:
        raise ValueError("Unsupported dataset.")

    batch_num = dataset.get_dataset_size()

    net = TinyDarkNet(num_classes=cfg.num_classes)
    # Continue training from a checkpoint if pre_trained is set to True
    if cfg.pre_trained:
        param_dict = load_checkpoint(cfg.checkpoint_path)
        load_param_into_net(net, param_dict)

    loss_scale_manager = None
    if args_opt.dataset_name == 'imagenet':
        lr = lr_steps_imagenet(cfg, batch_num)

        def get_param_groups(network):
            """get param groups"""
            decay_params = []
            no_decay_params = []
            for x in network.trainable_params():
                parameter_name = x.name
                if parameter_name.endswith('.bias'):
                    # biases do not use weight decay
                    no_decay_params.append(x)
                elif parameter_name.endswith('.gamma'):
                    # BatchNorm gamma does not use weight decay (careful: matched by
                    # name only, the parameter may not actually come from a BN layer)
                    no_decay_params.append(x)
                elif parameter_name.endswith('.beta'):
                    # BatchNorm beta does not use weight decay (careful: matched by
                    # name only, the parameter may not actually come from a BN layer)
                    no_decay_params.append(x)
                else:
                    decay_params.append(x)

            return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
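        # Only the remaining (convolution/dense weight) parameters receive weight
        # decay; leaving biases and normalization parameters unregularized is a
        # common convention for CNN training.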
        if cfg.is_dynamic_loss_scale:
            cfg.loss_scale = 1

        opt = Momentum(params=get_param_groups(net),
                       learning_rate=Tensor(lr),
                       momentum=cfg.momentum,
                       weight_decay=cfg.weight_decay,
                       loss_scale=cfg.loss_scale)
        if not cfg.use_label_smooth:
            cfg.label_smooth_factor = 0.0
        loss = CrossEntropySmooth(sparse=True, reduction="mean",
                                  smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)

    if cfg.is_dynamic_loss_scale:
        loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
    else:
        loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)

    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
                  amp_level="O3", loss_scale_manager=loss_scale_manager)
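    # amp_level="O3" runs the network in float16; the loss scale manager compensates
    # for the narrower float16 gradient range (dynamic scaling adapts the scale
    # during training, fixed scaling keeps cfg.loss_scale throughout).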
    config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 50, keep_checkpoint_max=cfg.keep_checkpoint_max)
    time_cb = TimeMonitor(data_size=batch_num)
    ckpt_save_dir = "./ckpt_" + str(rank) + "/"
    ckpoint_cb = ModelCheckpoint(prefix="train_tinydarknet_" + args_opt.dataset_name, directory=ckpt_save_dir,
                                 config=config_ck)
    loss_cb = LossMonitor()
    model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
    print("train success")
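The CrossEntropySmooth loss used above comes from src/CrossEntropySmooth.py, which is not shown on this page. As a rough illustration only (the class name CrossEntropySmoothSketch and its internals are assumptions, not the repository's implementation), a label-smoothing cross entropy can be sketched in MindSpore like this:

import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import operations as P


class CrossEntropySmoothSketch(nn.Cell):
    """Cross entropy with label smoothing (illustrative sketch, not the repo code)."""

    def __init__(self, sparse=True, reduction='mean', smooth_factor=0.0, num_classes=1000):
        super(CrossEntropySmoothSketch, self).__init__()
        self.sparse = sparse
        self.onehot = P.OneHot()
        # The true class gets probability 1 - smooth_factor; the remaining mass is
        # spread evenly over the other num_classes - 1 classes.
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            # Convert integer labels into smoothed one-hot targets.
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        return self.ce(logit, label)


# Hypothetical usage, mirroring the call in train.py:
# loss = CrossEntropySmoothSketch(sparse=True, reduction="mean", smooth_factor=0.1, num_classes=1000)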