
train.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train imagenet."""
import argparse
import math
import os
import random
import time

import numpy as np

import mindspore
from mindspore import Tensor, context
from mindspore.communication.management import get_group_size, get_rank, init
from mindspore.nn import SGD, RMSProp
from mindspore.train.callback import (CheckpointConfig, LossMonitor,
                                      ModelCheckpoint, TimeMonitor)
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import efficientnet_b0_config_gpu as cfg
from src.dataset import create_dataset
from src.efficientnet import efficientnet_b0
from src.loss import LabelSmoothingCrossEntropy

mindspore.common.set_seed(cfg.random_seed)
random.seed(cfg.random_seed)
np.random.seed(cfg.random_seed)
def get_lr(base_lr, total_epochs, steps_per_epoch, decay_steps=1,
           decay_rate=0.9, warmup_steps=0., warmup_lr_init=0., global_epoch=0):
    """Build a per-step learning-rate schedule: linear warmup for `warmup_steps`
    epochs, then stepwise exponential decay every `decay_steps` epochs. Steps
    before `global_epoch` are dropped so a resumed run continues from the right
    point in the schedule."""
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    global_steps = steps_per_epoch * global_epoch
    self_warmup_delta = ((base_lr - warmup_lr_init) /
                         warmup_steps) if warmup_steps > 0 else 0
    self_decay_rate = decay_rate if decay_rate < 1 else 1 / decay_rate
    for i in range(total_steps):
        steps = math.floor(i / steps_per_epoch)
        cond = 1 if (steps < warmup_steps) else 0
        warmup_lr = warmup_lr_init + steps * self_warmup_delta
        decay_nums = math.floor(steps / decay_steps)
        decay_rate = math.pow(self_decay_rate, decay_nums)
        decay_lr = base_lr * decay_rate
        lr = cond * warmup_lr + (1 - cond) * decay_lr
        lr_each_step.append(lr)
    lr_each_step = lr_each_step[global_steps:]
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    return lr_each_step
def get_outdir(path, *paths, inc=False):
    """Create (and return) the output directory; with `inc`, append an
    incrementing numeric suffix until an unused directory name is found."""
    outdir = os.path.join(path, *paths)
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    elif inc:
        count = 1
        outdir_inc = outdir + '-' + str(count)
        while os.path.exists(outdir_inc):
            count = count + 1
            outdir_inc = outdir + '-' + str(count)
            assert count < 100
        outdir = outdir_inc
        os.makedirs(outdir)
    return outdir
parser = argparse.ArgumentParser(
    description='Training configuration', add_help=False)
parser.add_argument('--data_path', type=str, default='/home/dataset/imagenet_jpeg/', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--distributed', action='store_true', default=False)
parser.add_argument('--GPU', action='store_true', default=False,
                    help='Use GPU for training (default: False)')
parser.add_argument('--cur_time', type=str,
                    default='19701010-000000', help='current time')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')
def main():
    args, _ = parser.parse_known_args()
    devid, rank_id, rank_size = 0, 0, 1
    context.set_context(mode=context.GRAPH_MODE)
    if args.distributed:
        if args.GPU:
            init("nccl")
            context.set_context(device_target='GPU')
        else:
            init()
            # On Ascend the device id comes from the DEVICE_ID environment variable.
            devid = int(os.getenv('DEVICE_ID'))
            context.set_context(
                device_target='Ascend', device_id=devid, reserve_class_name_in_scope=False)
        context.reset_auto_parallel_context()
        rank_id = get_rank()
        rank_size = get_group_size()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True, device_num=rank_size)
    else:
        if args.GPU:
            context.set_context(device_target='GPU')

    is_master = not args.distributed or (rank_id == 0)

    net = efficientnet_b0(num_classes=cfg.num_classes,
                          drop_rate=cfg.drop,
                          drop_connect_rate=cfg.drop_connect,
                          global_pool=cfg.gp,
                          bn_tf=cfg.bn_tf,
                          )

    cur_time = args.cur_time
    output_base = './output'
    exp_name = '-'.join([
        cur_time,
        cfg.model,
        str(224)
    ])
    # Stagger the ranks (rank 0 first) before creating the shared output directory.
    time.sleep(rank_id)
    output_dir = get_outdir(output_base, exp_name)

    train_data_url = os.path.join(args.data_path, 'train')
    train_dataset = create_dataset(
        cfg.batch_size, train_data_url, workers=cfg.workers, distributed=args.distributed)
    batches_per_epoch = train_dataset.get_dataset_size()

    loss_cb = LossMonitor(per_print_times=batches_per_epoch)
    loss = LabelSmoothingCrossEntropy(smooth_factor=cfg.smoothing)
    time_cb = TimeMonitor(data_size=batches_per_epoch)
    loss_scale_manager = FixedLossScaleManager(
        cfg.loss_scale, drop_overflow_update=False)
    callbacks = [time_cb, loss_cb]
    if cfg.save_checkpoint:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
        ckpoint_cb = ModelCheckpoint(
            prefix=cfg.model, directory=output_dir, config=config_ck)
        callbacks += [ckpoint_cb]

    lr = Tensor(get_lr(base_lr=cfg.lr, total_epochs=cfg.epochs, steps_per_epoch=batches_per_epoch,
                       decay_steps=cfg.decay_epochs, decay_rate=cfg.decay_rate,
                       warmup_steps=cfg.warmup_epochs, warmup_lr_init=cfg.warmup_lr_init,
                       global_epoch=cfg.resume_start_epoch))

    if cfg.opt == 'sgd':
        optimizer = SGD(net.trainable_params(), learning_rate=lr, momentum=cfg.momentum,
                        weight_decay=cfg.weight_decay,
                        loss_scale=cfg.loss_scale
                        )
    elif cfg.opt == 'rmsprop':
        optimizer = RMSProp(net.trainable_params(), learning_rate=lr, decay=0.9, weight_decay=cfg.weight_decay,
                            momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale
                            )

    loss.add_flags_recursive(fp32=True, fp16=False)

    if args.resume:
        ckpt = load_checkpoint(args.resume)
        load_param_into_net(net, ckpt)

    model = Model(net, loss, optimizer,
                  loss_scale_manager=loss_scale_manager,
                  amp_level=cfg.amp_level
                  )

    # Only the master rank reports progress and saves checkpoints.
    callbacks = callbacks if is_master else []

    if args.resume:
        # When resuming, train only the epochs that remain.
        real_epoch = cfg.epochs - cfg.resume_start_epoch
        model.train(real_epoch, train_dataset,
                    callbacks=callbacks, dataset_sink_mode=True)
    else:
        model.train(cfg.epochs, train_dataset,
                    callbacks=callbacks, dataset_sink_mode=True)


if __name__ == '__main__':
    main()
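
For reference, the standalone sketch below re-states the warmup-then-step-decay rule that get_lr implements, using small illustrative values (5 steps per epoch, 2 warmup epochs, decay every 2 epochs). The numbers are placeholders for demonstration only, not the values defined in src.config.

import math
import numpy as np

def toy_schedule(base_lr=0.1, total_epochs=8, steps_per_epoch=5,
                 decay_steps=2, decay_rate=0.9, warmup_steps=2, warmup_lr_init=0.0):
    # Same rule as get_lr above (assuming decay_rate < 1): linear warmup for
    # `warmup_steps` epochs, then the rate is multiplied by decay_rate once
    # every `decay_steps` epochs.
    warmup_delta = (base_lr - warmup_lr_init) / warmup_steps if warmup_steps > 0 else 0
    lrs = []
    for i in range(steps_per_epoch * total_epochs):
        epoch = i // steps_per_epoch
        if epoch < warmup_steps:
            lrs.append(warmup_lr_init + epoch * warmup_delta)
        else:
            lrs.append(base_lr * math.pow(decay_rate, epoch // decay_steps))
    return np.array(lrs, dtype=np.float32)

print(toy_schedule()[::5])  # one value per epoch: 0.0, 0.05, then 0.09, 0.09, 0.081, ...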