train.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train vgg16 example on cifar10########################
"""
import argparse
import datetime
import os

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.common import set_seed

from src.dataset import vgg_create_dataset
from src.dataset import classification_dataset
from src.crossentropy import CrossEntropy
from src.warmup_step_lr import warmup_step_lr
from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
from src.warmup_step_lr import lr_steps
from src.utils.logging import get_logger
from src.utils.util import get_param_groups
from src.vgg import vgg16

set_seed(1)
def parse_args(cloud_args=None):
    """parameters"""
    parser = argparse.ArgumentParser('mindspore classification training')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument('--device_id', type=int, default=1, help='device id of GPU or Ascend. (Default: 1)')

    # dataset related
    parser.add_argument('--dataset', type=str, choices=["cifar10", "imagenet2012"], default="cifar10")
    parser.add_argument('--data_path', type=str, default='', help='train data dir')

    # network related
    parser.add_argument('--pre_trained', default='', type=str, help='model_path, local pretrained model to load')
    parser.add_argument('--lr_gamma', type=float, default=0.1,
                        help='decrease lr by a factor of exponential lr_scheduler')
    parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler')
    parser.add_argument('--T_max', type=int, default=150, help='T-max in cosine_annealing scheduler')

    # logging and checkpoint related
    parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
    parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
    parser.add_argument('--ckpt_interval', type=int, default=5, help='ckpt_interval')
    parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')

    # distributed related
    parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
    parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
    parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
    args_opt = parser.parse_args()
    args_opt = merge_args(args_opt, cloud_args)

    # pick the dataset-specific config and copy its fields onto the parsed args
    if args_opt.dataset == "cifar10":
        from src.config import cifar_cfg as cfg
    else:
        from src.config import imagenet_cfg as cfg

    args_opt.label_smooth = cfg.label_smooth
    args_opt.label_smooth_factor = cfg.label_smooth_factor
    args_opt.lr_scheduler = cfg.lr_scheduler
    args_opt.loss_scale = cfg.loss_scale
    args_opt.max_epoch = cfg.max_epoch
    args_opt.warmup_epochs = cfg.warmup_epochs
    args_opt.lr = cfg.lr
    args_opt.lr_init = cfg.lr_init
    args_opt.lr_max = cfg.lr_max
    args_opt.momentum = cfg.momentum
    args_opt.weight_decay = cfg.weight_decay
    args_opt.per_batch_size = cfg.batch_size
    args_opt.num_classes = cfg.num_classes
    args_opt.buffer_size = cfg.buffer_size
    args_opt.ckpt_save_max = cfg.keep_checkpoint_max
    args_opt.pad_mode = cfg.pad_mode
    args_opt.padding = cfg.padding
    args_opt.has_bias = cfg.has_bias
    args_opt.batch_norm = cfg.batch_norm
    args_opt.initialize_mode = cfg.initialize_mode
    args_opt.has_dropout = cfg.has_dropout

    args_opt.lr_epochs = list(map(int, cfg.lr_epochs.split(',')))
    args_opt.image_size = list(map(int, cfg.image_size.split(',')))

    return args_opt
def merge_args(args_opt, cloud_args):
    """merge cloud-provided overrides into the parsed command-line args"""
    args_dict = vars(args_opt)
    if isinstance(cloud_args, dict):
        for key_arg in cloud_args.keys():
            val = cloud_args[key_arg]
            # only override keys the parser already knows, casting the value
            # to the type of the existing default
            if key_arg in args_dict and val:
                arg_type = type(args_dict[key_arg])
                if arg_type is not None:
                    val = arg_type(val)
                args_dict[key_arg] = val
    return args_opt
if __name__ == '__main__':
    args = parse_args()
    device_num = int(os.environ.get("DEVICE_NUM", 1))
    if args.is_distributed:
        if args.device_target == "Ascend":
            init()
            context.set_context(device_id=args.device_id)
        elif args.device_target == "GPU":
            init()

        args.rank = get_rank()
        args.group_size = get_group_size()
        device_num = args.group_size
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          parameter_broadcast=True, gradients_mean=True)
    else:
        context.set_context(device_id=args.device_id)
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

    # select for master rank save ckpt or all rank save, compatible for model parallel
    args.rank_save_ckpt_flag = 0
    if args.is_save_on_master:
        if args.rank == 0:
            args.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 1

    # logger
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)

    if args.dataset == "cifar10":
        dataset = vgg_create_dataset(args.data_path, args.image_size, args.per_batch_size, args.rank, args.group_size)
    else:
        dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size,
                                         args.rank, args.group_size)
    batch_num = dataset.get_dataset_size()
    args.steps_per_epoch = dataset.get_dataset_size()
    args.logger.save_args(args)

    # network
    args.logger.important_info('start create network')

    # get network and init
    network = vgg16(args.num_classes, args)

    # pre_trained
    if args.pre_trained:
        load_param_into_net(network, load_checkpoint(args.pre_trained))

    # lr scheduler
    if args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(args.lr,
                            args.lr_epochs,
                            args.steps_per_epoch,
                            args.warmup_epochs,
                            args.max_epoch,
                            gamma=args.lr_gamma,
                            )
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr,
                                        args.steps_per_epoch,
                                        args.warmup_epochs,
                                        args.max_epoch,
                                        args.T_max,
                                        args.eta_min)
    elif args.lr_scheduler == 'step':
        lr = lr_steps(0, lr_init=args.lr_init, lr_max=args.lr_max, warmup_epochs=args.warmup_epochs,
                      total_epochs=args.max_epoch, steps_per_epoch=batch_num)
    else:
        raise NotImplementedError(args.lr_scheduler)

    # optimizer
    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)

    if args.dataset == "cifar10":
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        model = Model(network, loss_fn=loss, optimizer=opt, metrics={'acc'},
                      amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
    else:
        if not args.label_smooth:
            args.label_smooth_factor = 0.0
        loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes)
        loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
        model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, amp_level="O2")

    # define callbacks
    time_cb = TimeMonitor(data_size=batch_num)
    loss_cb = LossMonitor(per_print_times=batch_num)
    callbacks = [time_cb, loss_cb]
    if args.rank_save_ckpt_flag:
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch,
                                       keep_checkpoint_max=args.ckpt_save_max)
        save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='{}'.format(args.rank))
        callbacks.append(ckpt_cb)

    model.train(args.max_epoch, dataset, callbacks=callbacks)
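The script is driven entirely by parse_args, so a launcher on a cloud platform only needs to hand it a dict of overrides. The snippet below is a minimal sketch, not part of the repository: it assumes train.py and the src/ package are on PYTHONPATH and that MindSpore is installed, and it shows how merge_args applies only keys the parser already defines, casting each value to the type of the existing default.

# hypothetical launcher sketch (assumes train.py, src/ and MindSpore are importable)
from train import parse_args

# keys must already exist in the argparse namespace; values are cast to the
# type of the corresponding default ('' -> str here) before being applied
cloud_args = {"data_path": "/cache/cifar10", "ckpt_path": "/cache/outputs/"}
args = parse_args(cloud_args)
print(args.data_path)   # '/cache/cifar10' instead of the empty default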