You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 8.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Transformer training script."""
  16. import time
  17. import argparse
  18. import random
  19. import numpy as np
  20. import mindspore.common.dtype as mstype
  21. from mindspore.common.tensor import Tensor
  22. from mindspore.nn.optim import Adam
  23. from mindspore.train.model import Model
  24. from mindspore.train.loss_scale_manager import DynamicLossScaleManager
  25. from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
  26. from mindspore.train.callback import Callback, TimeMonitor
  27. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  28. import mindspore.dataset.engine as de
  29. import mindspore.communication.management as D
  30. from mindspore.train.parallel_utils import ParallelMode
  31. from mindspore import context
  32. from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNetworkWithLoss, \
  33. TransformerTrainOneStepWithLossScaleCell
  34. from src.config import cfg, transformer_net_cfg
  35. from src.dataset import create_transformer_dataset
  36. from src.lr_schedule import create_dynamic_lr
  37. random_seed = 1
  38. random.seed(random_seed)
  39. np.random.seed(random_seed)
  40. de.config.set_seed(random_seed)
  41. def get_ms_timestamp():
  42. t = time.time()
  43. return int(round(t * 1000))
  44. time_stamp_init = False
  45. time_stamp_first = 0
  46. class LossCallBack(Callback):
  47. """
  48. Monitor the loss in training.
  49. If the loss is NAN or INF terminating training.
  50. Note:
  51. If per_print_times is 0 do not print loss.
  52. Args:
  53. per_print_times (int): Print loss every times. Default: 1.
  54. """
  55. def __init__(self, per_print_times=1):
  56. super(LossCallBack, self).__init__()
  57. if not isinstance(per_print_times, int) or per_print_times < 0:
  58. raise ValueError("print_step must be int and >= 0.")
  59. self._per_print_times = per_print_times
  60. global time_stamp_init, time_stamp_first
  61. if not time_stamp_init:
  62. time_stamp_first = get_ms_timestamp()
  63. time_stamp_init = True
  64. def step_end(self, run_context):
  65. global time_stamp_first
  66. time_stamp_current = get_ms_timestamp()
  67. cb_params = run_context.original_args()
  68. print("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first,
  69. cb_params.cur_epoch_num, cb_params.cur_step_num,
  70. str(cb_params.net_outputs)))
  71. with open("./loss.log", "a+") as f:
  72. f.write("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first,
  73. cb_params.cur_epoch_num,
  74. cb_params.cur_step_num,
  75. str(cb_params.net_outputs)))
  76. f.write('\n')
  77. def argparse_init():
  78. """
  79. Argparse init.
  80. """
  81. parser = argparse.ArgumentParser(description='transformer')
  82. parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
  83. parser.add_argument("--epoch_size", type=int, default=52, help="Epoch size, default is 52.")
  84. parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
  85. parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
  86. parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is true.")
  87. parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
  88. parser.add_argument("--enable_data_sink", type=str, default="false", help="Enable data sink, default is false.")
  89. parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path")
  90. parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, "
  91. "default is true.")
  92. parser.add_argument("--save_checkpoint_steps", type=int, default=2500, help="Save checkpoint steps, "
  93. "default is 2500.")
  94. parser.add_argument("--save_checkpoint_num", type=int, default=30, help="Save checkpoint numbers, default is 30.")
  95. parser.add_argument("--save_checkpoint_path", type=str, default="./checkpoint/", help="Save checkpoint file path, "
  96. "default is ./checkpoint/")
  97. parser.add_argument("--data_path", type=str, default="", help="Data path, it is better to use absolute path")
  98. return parser
  99. def run_transformer_train():
  100. """
  101. Transformer training.
  102. """
  103. parser = argparse_init()
  104. args, _ = parser.parse_known_args()
  105. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
  106. context.set_context(reserve_class_name_in_scope=False, enable_auto_mixed_precision=False)
  107. if args.distribute == "true":
  108. device_num = args.device_num
  109. context.reset_auto_parallel_context()
  110. context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
  111. parameter_broadcast=True, device_num=device_num)
  112. D.init()
  113. rank_id = args.device_id % device_num
  114. else:
  115. device_num = 1
  116. rank_id = 0
  117. dataset, repeat_count = create_transformer_dataset(epoch_count=args.epoch_size, rank_size=device_num,
  118. rank_id=rank_id, do_shuffle=args.do_shuffle,
  119. enable_data_sink=args.enable_data_sink,
  120. dataset_path=args.data_path)
  121. netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True)
  122. if args.checkpoint_path:
  123. parameter_dict = load_checkpoint(args.checkpoint_path)
  124. load_param_into_net(netwithloss, parameter_dict)
  125. lr = Tensor(create_dynamic_lr(schedule="constant*rsqrt_hidden*linear_warmup*rsqrt_decay",
  126. training_steps=dataset.get_dataset_size()*args.epoch_size,
  127. learning_rate=cfg.lr_schedule.learning_rate,
  128. warmup_steps=cfg.lr_schedule.warmup_steps,
  129. hidden_size=transformer_net_cfg.hidden_size,
  130. start_decay_step=cfg.lr_schedule.start_decay_step,
  131. min_lr=cfg.lr_schedule.min_lr), mstype.float32)
  132. optimizer = Adam(netwithloss.trainable_params(), lr)
  133. callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack()]
  134. if args.enable_save_ckpt == "true":
  135. ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps,
  136. keep_checkpoint_max=args.save_checkpoint_num)
  137. ckpoint_cb = ModelCheckpoint(prefix='transformer', directory=args.save_checkpoint_path, config=ckpt_config)
  138. callbacks.append(ckpoint_cb)
  139. if args.enable_lossscale == "true":
  140. scale_manager = DynamicLossScaleManager(init_loss_scale=cfg.init_loss_scale_value,
  141. scale_factor=cfg.scale_factor,
  142. scale_window=cfg.scale_window)
  143. update_cell = scale_manager.get_update_cell()
  144. netwithgrads = TransformerTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
  145. scale_update_cell=update_cell)
  146. else:
  147. netwithgrads = TransformerTrainOneStepCell(netwithloss, optimizer=optimizer)
  148. netwithgrads.set_train(True)
  149. model = Model(netwithgrads)
  150. model.train(repeat_count, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"))
  151. if __name__ == '__main__':
  152. run_transformer_train()