
train.py 5.8 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DSCNN train."""
import os
import datetime
import argparse

import numpy as np

from mindspore import context
from mindspore import Tensor, Model
from mindspore.nn.optim import Momentum
from mindspore.common import dtype as mstype
from mindspore.train.serialization import load_checkpoint

from src.config import train_config
from src.log import get_logger
from src.dataset import audio_dataset
from src.ds_cnn import DSCNN
from src.loss import CrossEntropy
from src.lr_scheduler import MultiStepLR, CosineAnnealingLR
from src.callback import ProgressMonitor, callback_func


def get_top5_acc(top5_arg, gt_class):
    sub_count = 0
    for top5, gt in zip(top5_arg, gt_class):
        if gt in top5:
            sub_count += 1
    return sub_count
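
# A quick worked example of the helper above (hypothetical values): with
# top5_arg = [[1, 3, 5, 7, 9], [0, 2, 4, 6, 8]] and gt_class = [3, 1], the
# first label (3) appears in its top-5 row while the second (1) does not, so
# get_top5_acc returns 1. Note it returns a raw hit count, not a rate; val()
# below divides the accumulated count by the total number of samples.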


def val(args, model, val_dataset):
    '''Eval.'''
    val_dataloader = val_dataset.create_tuple_iterator()
    img_tot = 0
    top1_correct = 0
    top5_correct = 0
    for data, gt_classes in val_dataloader:
        output = model.predict(Tensor(data, mstype.float32))
        output = output.asnumpy()
        top1_output = np.argmax(output, (-1))
        top5_output = np.argsort(output)[:, -5:]
        gt_classes = gt_classes.asnumpy()
        t1_correct = np.equal(top1_output, gt_classes).sum()
        top1_correct += t1_correct
        top5_correct += get_top5_acc(top5_output, gt_classes)
        img_tot += output.shape[0]
    results = [[top1_correct], [top5_correct], [img_tot]]
    results = np.array(results)
    top1_correct = results[0, 0]
    top5_correct = results[1, 0]
    img_tot = results[2, 0]
    acc1 = 100.0 * top1_correct / img_tot
    acc5 = 100.0 * top5_correct / img_tot
    if acc1 > args.best_acc:
        args.best_acc = acc1
        args.best_epoch = args.epoch_cnt - 1
    args.logger.info('Eval: top1_cor:{}, top5_cor:{}, tot:{}, acc@1={:.2f}%, acc@5={:.2f}%' \
                     .format(top1_correct, top5_correct, img_tot, acc1, acc5))
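
# Note on the top-5 extraction above: np.argsort(output)[:, -5:] keeps the
# indices of the five largest logits in each row (in ascending order, which is
# irrelevant for the membership test in get_top5_acc). The reported metrics are
# then simply acc@k = 100 * correct_k / img_tot over the whole validation set.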


def trainval(args, model, train_dataset, val_dataset, cb):
    callbacks = callback_func(args, cb, 'epoch{}'.format(args.epoch_cnt))
    model.train(args.val_interval, train_dataset, callbacks=callbacks, dataset_sink_mode=args.dataset_sink_mode)
    val(args, model, val_dataset)
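
# Each trainval() call trains for args.val_interval epochs and then runs a
# single validation pass, so evaluation happens every val_interval epochs.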


def train():
    '''Train.'''
    parser = argparse.ArgumentParser()
    parser.add_argument('--device_id', type=int, default=1, help='which device the model will be trained on')
    args, model_settings = train_config(parser)
    context.set_context(mode=context.GRAPH_MODE, device_id=args.device_id, enable_auto_mixed_precision=True)
    args.rank_save_ckpt_flag = 1

    # Logger
    args.outputs_dir = os.path.join(args.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir)

    # Dataloader: train, val
    train_dataset = audio_dataset(args.feat_dir, 'training', model_settings['spectrogram_length'],
                                  model_settings['dct_coefficient_count'], args.per_batch_size)
    args.steps_per_epoch = train_dataset.get_dataset_size()
    val_dataset = audio_dataset(args.feat_dir, 'validation', model_settings['spectrogram_length'],
                                model_settings['dct_coefficient_count'], args.per_batch_size)
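
    # Each sample is expected to be a spectrogram_length x dct_coefficient_count
    # MFCC feature map (the usual DS-CNN keyword-spotting input); the exact
    # feature layout is defined in src/dataset.py.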

    # show args
    args.logger.save_args(args)

    # Network
    args.logger.important_info('start create network')
    network = DSCNN(model_settings, args.model_size_info)

    # Load pretrain model
    if os.path.isfile(args.pretrained):
        load_checkpoint(args.pretrained, network)
        args.logger.info('load model {} success'.format(args.pretrained))

    # Loss
    criterion = CrossEntropy(num_classes=model_settings['label_count'])
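
    # src/loss.py's CrossEntropy is assumed to be standard softmax
    # cross-entropy over label_count classes, i.e. for logits z and label y:
    #   loss = -log(exp(z_y) / sum_j exp(z_j))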

    # LR scheduler
    if args.lr_scheduler == 'multistep':
        lr_scheduler = MultiStepLR(args.lr, args.lr_epochs, args.lr_gamma, args.steps_per_epoch,
                                   args.max_epoch, warmup_epochs=args.warmup_epochs)
    elif args.lr_scheduler == 'cosine_annealing':
        lr_scheduler = CosineAnnealingLR(args.lr, args.T_max, args.steps_per_epoch, args.max_epoch,
                                         warmup_epochs=args.warmup_epochs, eta_min=args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)
    lr_schedule = lr_scheduler.get_lr()
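
    # The cosine_annealing branch presumably follows the usual schedule
    #   lr(T_cur) = eta_min + (lr - eta_min) * (1 + cos(pi * T_cur / T_max)) / 2
    # after warmup; the exact per-step values (including warmup) come from
    # src/lr_scheduler.py. get_lr() materializes one learning rate per training
    # step, which is why the schedule can be handed to the optimizer as a Tensor.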

    # Optimizer
    opt = Momentum(params=network.trainable_params(),
                   learning_rate=Tensor(lr_schedule),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay)
    model = Model(network, loss_fn=criterion, optimizer=opt, amp_level='O0')
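
    # Sketch of MindSpore's Momentum update (weight decay is folded into the
    # gradient first): v <- momentum * v + grad; p <- p - lr * v.
    # amp_level='O0' keeps the Model wrapper in FP32; any mixed precision here
    # comes from the enable_auto_mixed_precision context flag set above.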

    # Training
    args.epoch_cnt = 0
    args.best_epoch = 0
    args.best_acc = 0
    progress_cb = ProgressMonitor(args)
    while args.epoch_cnt + args.val_interval < args.max_epoch:
        trainval(args, model, train_dataset, val_dataset, progress_cb)
    rest_ep = args.max_epoch - args.epoch_cnt
    if rest_ep > 0:
        trainval(args, model, train_dataset, val_dataset, progress_cb)
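
    # Loop accounting: the while loop advances in val_interval-epoch chunks
    # (ProgressMonitor is expected to keep args.epoch_cnt up to date), and the
    # trailing trainval() call covers any remaining epochs before max_epoch.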
    args.logger.info('Best epoch:{} acc:{:.2f}%'.format(args.best_epoch, args.best_acc))


if __name__ == "__main__":
    train()