
train.py 5.4 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse

import numpy as np

from mindspore import context, Tensor
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_group_size, get_rank
from mindspore.train import Model
from mindspore.train.callback import TimeMonitor, LossMonitor, ModelCheckpoint, CheckpointConfig
from mindspore.nn.optim import Adam
from mindspore.common import set_seed

from src.config import config
from src.model import get_pose_net
from src.network_define import JointsMSELoss, WithLossCell
from src.dataset import keypoint_dataset

set_seed(1)
# Fall back to device 0 so the script still runs when DEVICE_ID is unset.
device_id = int(os.getenv('DEVICE_ID', '0'))


def get_lr(begin_epoch,
           total_epochs,
           steps_per_epoch,
           lr_init=0.1,
           factor=0.1,
           epoch_number_to_drop=(90, 120)
           ):
    """
    Generate learning rate array.

    Args:
        begin_epoch (int): Initial epoch of training.
        total_epochs (int): Total epochs of training.
        steps_per_epoch (int): Steps of one epoch.
        lr_init (float): Initial learning rate. Default: 0.1.
        factor (float): Factor by which the learning rate is multiplied at each drop. Default: 0.1.
        epoch_number_to_drop (tuple): Learning rate drops after these epochs. Default: (90, 120).

    Returns:
        np.array, learning rate array.
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    step_number_to_drop = [steps_per_epoch * x for x in epoch_number_to_drop]
    for i in range(int(total_steps)):
        if i in step_number_to_drop:
            lr_init = lr_init * factor
        lr_each_step.append(lr_init)
    # Skip the steps already covered by epochs before begin_epoch.
    current_step = int(steps_per_epoch * begin_epoch)
    lr_each_step = np.array(lr_each_step, dtype=np.float32)
    learning_rate = lr_each_step[current_step:]
    return learning_rate


def parse_args():
    parser = argparse.ArgumentParser(description="SimplePoseNet training")
    parser.add_argument("--run-distribute",
                        help="Run distributed training; default is false.",
                        action='store_true')
    parser.add_argument('--ckpt-path', type=str, help='path to save checkpoints')
    parser.add_argument('--batch-size', type=int, help='training batch size')
    args = parser.parse_args()
    return args


def main():
    # load args and config
    print("parsing arguments...")
    args = parse_args()
    if args.batch_size:
        config.TRAIN.BATCH_SIZE = args.batch_size
    print('batch size: {}'.format(config.TRAIN.BATCH_SIZE))
    # distribution and context
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        save_graphs=False,
                        device_id=device_id)
    if args.run_distribute:
        init()
        rank = get_rank()
        device_num = get_group_size()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
    else:
        rank = 0
        device_num = 1
    # only rank 0 writes checkpoints
    rank_save_flag = False
    if rank == 0 or device_num == 1:
        rank_save_flag = True
    # create dataset
    dataset, _ = keypoint_dataset(config,
                                  rank=rank,
                                  group_size=device_num,
                                  train_mode=True,
                                  num_parallel_workers=8)
    # network
    net = get_pose_net(config, True, ckpt_path=config.MODEL.PRETRAINED)
    loss = JointsMSELoss(use_target_weight=True)
    net_with_loss = WithLossCell(net, loss)
    # lr schedule and optimizer
    dataset_size = dataset.get_dataset_size()
    lr = Tensor(get_lr(config.TRAIN.BEGIN_EPOCH,
                       config.TRAIN.END_EPOCH,
                       dataset_size,
                       lr_init=config.TRAIN.LR,
                       factor=config.TRAIN.LR_FACTOR,
                       epoch_number_to_drop=config.TRAIN.LR_STEP))
    opt = Adam(net.trainable_params(), learning_rate=lr)
    # callbacks
    time_cb = TimeMonitor(data_size=dataset_size)
    loss_cb = LossMonitor()
    cb = [time_cb, loss_cb]
    if args.ckpt_path and rank_save_flag:
        config_ck = CheckpointConfig(save_checkpoint_steps=dataset_size, keep_checkpoint_max=20)
        ckpoint_cb = ModelCheckpoint(prefix="simplepose", directory=args.ckpt_path, config=config_ck)
        cb.append(ckpoint_cb)
    # train model
    model = Model(net_with_loss, loss_fn=None, optimizer=opt, amp_level="O2")
    epoch_size = config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH
    print('start training, epoch size = %d' % epoch_size)
    model.train(epoch_size, dataset, callbacks=cb)


if __name__ == '__main__':
    main()
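
For reference, get_lr above materializes the whole piecewise-constant schedule up front and then slices off the steps already covered by begin_epoch. A minimal sanity check of the drop points, using illustrative numbers rather than the values in src.config:

# Illustrative check of get_lr from train.py above: 140 epochs of 100 steps,
# with drops after epochs 90 and 120 (example values, not the project's config).
lr = get_lr(begin_epoch=0, total_epochs=140, steps_per_epoch=100,
            lr_init=0.001, factor=0.1, epoch_number_to_drop=(90, 120))
print(lr.shape)    # (14000,)
print(lr[0])       # 0.001
print(lr[9000])    # 0.0001 -- first drop at step 90 * 100
print(lr[12000])   # 1e-05  -- second drop at step 120 * 100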