You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 8.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """train_imagenet."""
  16. import time
  17. import argparse
  18. import numpy as np
  19. from mindspore import context
  20. from mindspore import Tensor
  21. from mindspore import nn
  22. from mindspore.nn.optim.momentum import Momentum
  23. from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
  24. from mindspore.nn.loss.loss import _Loss
  25. from mindspore.ops import operations as P
  26. from mindspore.ops import functional as F
  27. from mindspore.common import dtype as mstype
  28. from mindspore.train.model import Model
  29. from mindspore.context import ParallelMode
  30. from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
  31. from mindspore.train.loss_scale_manager import FixedLossScaleManager
  32. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  33. from mindspore.common import set_seed
  34. from mindspore.communication.management import init, get_group_size, get_rank
  35. from src.dataset import create_dataset
  36. from src.lr_generator import get_lr
  37. from src.config import config_gpu
  38. from src.mobilenetV3 import mobilenet_v3_large
  39. set_seed(1)
  40. parser = argparse.ArgumentParser(description='Image classification')
  41. parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
  42. parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
  43. parser.add_argument('--device_target', type=str, default="GPU", help='run device_target')
  44. args_opt = parser.parse_args()
  45. if args_opt.device_target == "GPU":
  46. context.set_context(mode=context.GRAPH_MODE,
  47. device_target="GPU",
  48. save_graphs=False)
  49. init()
  50. context.set_auto_parallel_context(device_num=get_group_size(),
  51. parallel_mode=ParallelMode.DATA_PARALLEL,
  52. mirror_mean=True)
  53. else:
  54. raise ValueError("Unsupported device_target.")
  55. class CrossEntropyWithLabelSmooth(_Loss):
  56. """
  57. CrossEntropyWith LabelSmooth.
  58. Args:
  59. smooth_factor (float): smooth factor, default=0.
  60. num_classes (int): num classes
  61. Returns:
  62. None.
  63. Examples:
  64. >>> CrossEntropyWithLabelSmooth(smooth_factor=0., num_classes=1000)
  65. """
  66. def __init__(self, smooth_factor=0., num_classes=1000):
  67. super(CrossEntropyWithLabelSmooth, self).__init__()
  68. self.onehot = P.OneHot()
  69. self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
  70. self.off_value = Tensor(1.0 * smooth_factor /
  71. (num_classes - 1), mstype.float32)
  72. self.ce = nn.SoftmaxCrossEntropyWithLogits()
  73. self.mean = P.ReduceMean(False)
  74. self.cast = P.Cast()
  75. def construct(self, logit, label):
  76. one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(logit)[1],
  77. self.on_value, self.off_value)
  78. out_loss = self.ce(logit, one_hot_label)
  79. out_loss = self.mean(out_loss, 0)
  80. return out_loss
  81. class Monitor(Callback):
  82. """
  83. Monitor loss and time.
  84. Args:
  85. lr_init (numpy array): train lr
  86. Returns:
  87. None
  88. Examples:
  89. >>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy())
  90. """
  91. def __init__(self, lr_init=None):
  92. super(Monitor, self).__init__()
  93. self.lr_init = lr_init
  94. self.lr_init_len = len(lr_init)
  95. def epoch_begin(self, run_context):
  96. self.losses = []
  97. self.epoch_time = time.time()
  98. def epoch_end(self, run_context):
  99. cb_params = run_context.original_args()
  100. epoch_mseconds = (time.time() - self.epoch_time) * 1000
  101. per_step_mseconds = epoch_mseconds / cb_params.batch_num
  102. print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds,
  103. per_step_mseconds,
  104. np.mean(self.losses)))
  105. def step_begin(self, run_context):
  106. self.step_time = time.time()
  107. def step_end(self, run_context):
  108. cb_params = run_context.original_args()
  109. step_mseconds = (time.time() - self.step_time) * 1000
  110. step_loss = cb_params.net_outputs
  111. if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
  112. step_loss = step_loss[0]
  113. if isinstance(step_loss, Tensor):
  114. step_loss = np.mean(step_loss.asnumpy())
  115. self.losses.append(step_loss)
  116. cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
  117. print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format(
  118. cb_params.cur_epoch_num -
  119. 1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss,
  120. np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))
  121. if __name__ == '__main__':
  122. if args_opt.device_target == "GPU":
  123. # train on gpu
  124. print("train args: ", args_opt)
  125. print("cfg: ", config_gpu)
  126. # define net
  127. net = mobilenet_v3_large(num_classes=config_gpu.num_classes)
  128. # define loss
  129. if config_gpu.label_smooth > 0:
  130. loss = CrossEntropyWithLabelSmooth(
  131. smooth_factor=config_gpu.label_smooth, num_classes=config_gpu.num_classes)
  132. else:
  133. loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
  134. # define dataset
  135. epoch_size = config_gpu.epoch_size
  136. dataset = create_dataset(dataset_path=args_opt.dataset_path,
  137. do_train=True,
  138. config=config_gpu,
  139. device_target=args_opt.device_target,
  140. repeat_num=1,
  141. batch_size=config_gpu.batch_size)
  142. step_size = dataset.get_dataset_size()
  143. # resume
  144. if args_opt.pre_trained:
  145. param_dict = load_checkpoint(args_opt.pre_trained)
  146. load_param_into_net(net, param_dict)
  147. # define optimizer
  148. loss_scale = FixedLossScaleManager(
  149. config_gpu.loss_scale, drop_overflow_update=False)
  150. lr = Tensor(get_lr(global_step=0,
  151. lr_init=0,
  152. lr_end=0,
  153. lr_max=config_gpu.lr,
  154. warmup_epochs=config_gpu.warmup_epochs,
  155. total_epochs=epoch_size,
  156. steps_per_epoch=step_size))
  157. opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_gpu.momentum,
  158. config_gpu.weight_decay, config_gpu.loss_scale)
  159. # define model
  160. model = Model(net, loss_fn=loss, optimizer=opt,
  161. loss_scale_manager=loss_scale)
  162. cb = [Monitor(lr_init=lr.asnumpy())]
  163. ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
  164. if config_gpu.save_checkpoint:
  165. config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size,
  166. keep_checkpoint_max=config_gpu.keep_checkpoint_max)
  167. ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=ckpt_save_dir, config=config_ck)
  168. cb += [ckpt_cb]
  169. # begine train
  170. model.train(epoch_size, dataset, callbacks=cb)