
train.py 7.3 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train imagenet."""
import argparse
import math
import os
import random

import numpy as np

import mindspore
from mindspore import Tensor, context
from mindspore.communication.management import get_group_size, get_rank, init
from mindspore.nn import SGD, RMSProp
from mindspore.train.callback import (CheckpointConfig, LossMonitor,
                                      ModelCheckpoint, TimeMonitor)
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import efficientnet_b0_config_gpu as cfg
from src.dataset import create_dataset
from src.efficientnet import efficientnet_b0
from src.loss import LabelSmoothingCrossEntropy

mindspore.common.set_seed(cfg.random_seed)
random.seed(cfg.random_seed)
np.random.seed(cfg.random_seed)

def get_lr(base_lr, total_epochs, steps_per_epoch, decay_steps=1,
           decay_rate=0.9, warmup_steps=0., warmup_lr_init=0., global_epoch=0):
    """Build a per-step learning-rate schedule: linear warmup (stepped once
    per epoch) followed by exponential decay every `decay_steps` epochs."""
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    global_steps = steps_per_epoch * global_epoch
    self_warmup_delta = ((base_lr - warmup_lr_init) /
                         warmup_steps) if warmup_steps > 0 else 0
    # A decay_rate >= 1 is interpreted as its reciprocal.
    self_decay_rate = decay_rate if decay_rate < 1 else 1 / decay_rate
    for i in range(total_steps):
        steps = math.floor(i / steps_per_epoch)
        cond = 1 if (steps < warmup_steps) else 0
        warmup_lr = warmup_lr_init + steps * self_warmup_delta
        decay_nums = math.floor(steps / decay_steps)
        decay_rate = math.pow(self_decay_rate, decay_nums)
        decay_lr = base_lr * decay_rate
        lr = cond * warmup_lr + (1 - cond) * decay_lr
        lr_each_step.append(lr)
    # When resuming, drop the steps that were already trained.
    lr_each_step = lr_each_step[global_steps:]
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    return lr_each_step
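
# Worked example (hypothetical values, for illustration only):
# get_lr(base_lr=0.1, total_epochs=3, steps_per_epoch=2, decay_steps=1,
#        decay_rate=0.9, warmup_steps=1, warmup_lr_init=0.01)
# returns [0.01, 0.01, 0.09, 0.09, 0.081, 0.081]: one warmup epoch held at
# warmup_lr_init, then base_lr * 0.9**epoch for each subsequent epoch.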

def get_outdir(path, *paths, inc=False):
    """Create (and return) an output directory; with `inc=True`, append an
    incrementing '-N' suffix rather than reuse an existing directory."""
    outdir = os.path.join(path, *paths)
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    elif inc:
        count = 1
        outdir_inc = outdir + '-' + str(count)
        while os.path.exists(outdir_inc):
            count = count + 1
            outdir_inc = outdir + '-' + str(count)
            assert count < 100
        outdir = outdir_inc
        os.makedirs(outdir)
    return outdir
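
# Example (hypothetical paths): get_outdir('./output', 'exp') creates and
# returns './output/exp'; a later call with inc=True returns './output/exp-1',
# then './output/exp-2', and so on (capped at 99 suffixes by the assert).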

parser = argparse.ArgumentParser(
    description='Training configuration', add_help=False)
parser.add_argument('--data_path', type=str, default='/home/dataset/imagenet_jpeg/', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--distributed', action='store_true', default=False)
parser.add_argument('--GPU', action='store_true', default=False,
                    help='Use GPU for training (default: False)')
parser.add_argument('--cur_time', type=str,
                    default='19701010-000000', help='current time')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')

def main():
    args, _ = parser.parse_known_args()
    rank_id, rank_size = 0, 1
    context.set_context(mode=context.GRAPH_MODE)

    if args.distributed:
        if args.GPU:
            # Initialize NCCL-based collective communication for multi-GPU.
            init("nccl")
            context.set_context(device_target='GPU')
        else:
            raise ValueError("Only supported GPU training.")
        context.reset_auto_parallel_context()
        rank_id = get_rank()
        rank_size = get_group_size()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True, device_num=rank_size)
    else:
        if args.GPU:
            context.set_context(device_target='GPU')
        else:
            raise ValueError("Only supported GPU training.")

    net = efficientnet_b0(num_classes=cfg.num_classes,
                          drop_rate=cfg.drop,
                          drop_connect_rate=cfg.drop_connect,
                          global_pool=cfg.gp,
                          bn_tf=cfg.bn_tf,
                          )

    train_data_url = args.data_path
    train_dataset = create_dataset(
        cfg.batch_size, train_data_url, workers=cfg.workers, distributed=args.distributed)
    batches_per_epoch = train_dataset.get_dataset_size()

    loss_cb = LossMonitor(per_print_times=batches_per_epoch)
    loss = LabelSmoothingCrossEntropy(smooth_factor=cfg.smoothing)
    time_cb = TimeMonitor(data_size=batches_per_epoch)
    loss_scale_manager = FixedLossScaleManager(
        cfg.loss_scale, drop_overflow_update=False)

    callbacks = [time_cb, loss_cb]
    if cfg.save_checkpoint:
        # Save one checkpoint per epoch, per rank.
        config_ck = CheckpointConfig(
            save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
        ckpoint_cb = ModelCheckpoint(
            prefix=cfg.model, directory='./ckpt_' + str(rank_id) + '/', config=config_ck)
        callbacks += [ckpoint_cb]

    lr = Tensor(get_lr(base_lr=cfg.lr, total_epochs=cfg.epochs, steps_per_epoch=batches_per_epoch,
                       decay_steps=cfg.decay_epochs, decay_rate=cfg.decay_rate,
                       warmup_steps=cfg.warmup_epochs, warmup_lr_init=cfg.warmup_lr_init,
                       global_epoch=cfg.resume_start_epoch))

    if cfg.opt == 'sgd':
        optimizer = SGD(net.trainable_params(), learning_rate=lr, momentum=cfg.momentum,
                        weight_decay=cfg.weight_decay,
                        loss_scale=cfg.loss_scale
                        )
    elif cfg.opt == 'rmsprop':
        optimizer = RMSProp(net.trainable_params(), learning_rate=lr, decay=0.9, weight_decay=cfg.weight_decay,
                            momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale
                            )
    else:
        raise ValueError("Unsupported optimizer: " + cfg.opt)

    # Keep the loss computation in fp32 even under mixed precision.
    loss.add_flags_recursive(fp32=True, fp16=False)

    if args.resume:
        ckpt = load_checkpoint(args.resume)
        load_param_into_net(net, ckpt)

    model = Model(net, loss, optimizer,
                  loss_scale_manager=loss_scale_manager,
                  amp_level=cfg.amp_level
                  )

    # callbacks = callbacks if is_master else []

    if args.resume:
        # When resuming, train only the remaining epochs.
        real_epoch = cfg.epochs - cfg.resume_start_epoch
        model.train(real_epoch, train_dataset,
                    callbacks=callbacks, dataset_sink_mode=True)
    else:
        model.train(cfg.epochs, train_dataset,
                    callbacks=callbacks, dataset_sink_mode=True)


if __name__ == '__main__':
    main()
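
A typical launch might look like the following (the dataset path and process count are illustrative; the distributed case assumes an OpenMPI-style launcher, since init("nccl") expects one process per device):

    python train.py --GPU --data_path /path/to/imagenet
    mpirun -n 8 python train.py --GPU --distributed --data_path /path/to/imagenet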