
train.py 7.8 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train imagenet"""
import os
import argparse
import math
import numpy as np
from mindspore.communication import init, get_rank
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
from mindspore.train.model import ParallelMode
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore import Model
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn import RMSProp
from mindspore import Tensor
from mindspore import context
from mindspore.common import set_seed
from mindspore.common.initializer import XavierUniform, initializer
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.inceptionv4 import Inceptionv4
from src.dataset import create_dataset, device_num
from src.config import config

os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
set_seed(1)


def generate_cosine_lr(steps_per_epoch, total_epochs,
                       lr_init=config.lr_init,
                       lr_end=config.lr_end,
                       lr_max=config.lr_max,
                       warmup_epochs=config.warmup_epochs):
    """
    Applies cosine decay to generate the learning rate array.

    Args:
        steps_per_epoch(int): number of steps per epoch.
        total_epochs(int): total number of training epochs.
        lr_init(float): initial learning rate.
        lr_end(float): final learning rate.
        lr_max(float): maximum learning rate.
        warmup_epochs(int): number of warmup epochs.

    Returns:
        np.array, learning rate array.
    """
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    decay_steps = total_steps - warmup_steps
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            # linear warmup from lr_init up to lr_max
            lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
            lr = float(lr_init) + lr_inc * (i + 1)
        else:
            # cosine decay from lr_max down to lr_end
            cosine_decay = 0.5 * (1 + math.cos(math.pi * (i - warmup_steps) / decay_steps))
            lr = (lr_max - lr_end) * cosine_decay + lr_end
        lr_each_step.append(lr)
    learning_rate = np.array(lr_each_step).astype(np.float32)
    # when resuming from start_epoch > 1, drop the schedule entries already consumed
    current_step = steps_per_epoch * (config.start_epoch - 1)
    learning_rate = learning_rate[current_step:]
    return learning_rate
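

# A quick sanity check of the schedule shape (toy values for illustration, not
# the repo defaults; assumes config.start_epoch == 1 so no entries are dropped):
#
#   generate_cosine_lr(steps_per_epoch=2, total_epochs=3,
#                      lr_init=0.0, lr_end=0.0, lr_max=0.4, warmup_epochs=1)
#   # -> approximately [0.2, 0.4, 0.4, 0.3414, 0.2, 0.0586]
#
# The first two steps ramp up linearly to lr_max; the remaining four follow the
# cosine curve from lr_max back down toward lr_end.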


def inception_v4_train():
    """
    Train Inceptionv4 with data parallelism.
    """
    print('epoch_size: {} batch_size: {} class_num {}'.format(
        config.epoch_size, config.batch_size, config.num_classes))
    context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
    if args.platform == "Ascend":
        context.set_context(device_id=args.device_id)
        context.set_context(enable_graph_kernel=False)

    rank = 0
    if device_num > 1:
        if args.platform == "Ascend":
            init(backend_name='hccl')
        elif args.platform == "GPU":
            init()
        else:
            raise ValueError("Unsupported device target.")
        rank = get_rank()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True,
                                          all_reduce_fusion_config=[200, 400])
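        # DATA_PARALLEL replicates the network on every device and splits each
        # batch across them; gradients_mean=True averages gradients during the
        # all-reduce step. all_reduce_fusion_config fuses the gradient all-reduce
        # operators at the given parameter indices to reduce communication overhead.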
    # create dataset
    train_dataset = create_dataset(dataset_path=args.dataset_path, do_train=True,
                                   repeat_num=1, batch_size=config.batch_size, shard_id=rank)
    train_step_size = train_dataset.get_dataset_size()

    # create model
    net = Inceptionv4(classes=config.num_classes)

    # loss
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")

    # learning rate
    lr = Tensor(generate_cosine_lr(steps_per_epoch=train_step_size, total_epochs=config.epoch_size))

    decayed_params = []
    no_decayed_params = []
    for param in net.trainable_params():
        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
            decayed_params.append(param)
        else:
            no_decayed_params.append(param)
    for param in net.trainable_params():
        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
            param.set_data(initializer(XavierUniform(), param.data.shape, param.data.dtype))

    group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
                    {'params': no_decayed_params},
                    {'order_params': net.trainable_params()}]
    opt = RMSProp(group_params, lr, decay=config.decay, epsilon=config.epsilon, weight_decay=config.weight_decay,
                  momentum=config.momentum, loss_scale=config.loss_scale)
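    # Weight decay applies only to convolution/dense weights; BatchNorm beta/gamma
    # and bias parameters are exempt, and those same weights are re-initialized
    # with Xavier uniform above. The 'order_params' entry keeps the optimizer's
    # parameter order aligned with the network's trainable parameters.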
    if args.device_id == 0:
        print(lr)
        print(train_step_size)

    if args.resume:
        ckpt = load_checkpoint(args.resume)
        load_param_into_net(net, ckpt)

    loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)

    if args.platform == "Ascend":
        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc', 'top_1_accuracy', 'top_5_accuracy'},
                      loss_scale_manager=loss_scale_manager, amp_level=config.amp_level)
    elif args.platform == "GPU":
        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc', 'top_1_accuracy', 'top_5_accuracy'},
                      loss_scale_manager=loss_scale_manager, amp_level='O0')
    else:
        raise ValueError("Unsupported device target.")
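    # On Ascend the mixed-precision level comes from config.amp_level, while
    # amp_level='O0' keeps GPU training in full FP32. FixedLossScaleManager with
    # drop_overflow_update=False applies a constant loss scale, which guards
    # reduced-precision gradients against underflow.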
    # define callbacks
    performance_cb = TimeMonitor(data_size=train_step_size)
    loss_cb = LossMonitor(per_print_times=train_step_size)
    ckp_save_step = config.save_checkpoint_epochs * train_step_size
    config_ck = CheckpointConfig(save_checkpoint_steps=ckp_save_step, keep_checkpoint_max=config.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=f"inceptionV4-train-rank{rank}",
                                 directory='ckpts_rank_' + str(rank), config=config_ck)
    callbacks = [performance_cb, loss_cb]
    if device_num > 1 and config.is_save_on_master:
        if args.device_id == 0:
            callbacks.append(ckpoint_cb)
    else:
        callbacks.append(ckpoint_cb)

    # train model
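    # dataset_sink_mode=True streams batches to the device through a data queue
    # instead of feeding each step from the host, the usual high-throughput
    # setting for graph mode on Ascend/GPU.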
    model.train(config.epoch_size, train_dataset, callbacks=callbacks, dataset_sink_mode=True)


def parse_args():
    '''parse_args'''
    arg_parser = argparse.ArgumentParser(description='InceptionV4 image classification training')
    arg_parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    arg_parser.add_argument('--device_id', type=int, default=0, help='device id')
    arg_parser.add_argument('--platform', type=str, default='Ascend', choices=("Ascend", "GPU"),
                            help='Platform, support Ascend, GPU.')
    arg_parser.add_argument('--resume', type=str, default='', help='resume training from an existing checkpoint')
    args_opt = arg_parser.parse_args()
    return args_opt
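

# Example invocations (paths are placeholders, not part of the repo):
#
#   # single-device training on Ascend
#   python train.py --dataset_path=/path/to/imagenet/train --platform=Ascend --device_id=0
#
#   # resume from a previously saved checkpoint
#   python train.py --dataset_path=/path/to/imagenet/train --resume=/path/to/checkpoint.ckpt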


if __name__ == '__main__':
    args = parse_args()
    inception_v4_train()
    print('Inceptionv4 training success!')