You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Train api."""
  16. import os
  17. import argparse
  18. import pickle
  19. import numpy as np
  20. import mindspore.common.dtype as mstype
  21. from mindspore.common.tensor import Tensor
  22. from mindspore.nn import Momentum
  23. from mindspore.nn.optim import Adam, Lamb
  24. from mindspore.train.model import Model
  25. from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
  26. from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
  27. from mindspore import context, ParallelMode, Parameter
  28. from mindspore.communication import management as MultiAscend
  29. from mindspore.train.serialization import load_checkpoint
  30. from config import TransformerConfig
  31. from src.dataset import load_dataset
  32. from src.transformer import TransformerNetworkWithLoss, TransformerTrainOneStepWithLossScaleCell
  33. from src.transformer.infer_mass import infer
  34. from src.utils import LossCallBack
  35. from src.utils import one_weight, zero_weight, weight_variable
  36. from src.utils import square_root_schedule
  37. from src.utils.lr_scheduler import polynomial_decay_scheduler, BertLearningRate
  38. parser = argparse.ArgumentParser(description='MASS train entry point.')
  39. parser.add_argument("--config", type=str, required=True, help="model config json file path.")
  40. parser.add_argument("--platform", type=str, required=True, help="model working platform.")
  41. def get_config(config):
  42. config = TransformerConfig.from_json_file(config)
  43. config.compute_type = mstype.float16
  44. config.dtype = mstype.float32
  45. return config
  46. def _train(model, config: TransformerConfig,
  47. pre_training_dataset=None, fine_tune_dataset=None, test_dataset=None,
  48. callbacks: list = None):
  49. """
  50. Train model.
  51. Args:
  52. model (Model): MindSpore model instance.
  53. config (TransformerConfig): Config of mass model.
  54. pre_training_dataset (Dataset): Pre-training dataset.
  55. fine_tune_dataset (Dataset): Fine-tune dataset.
  56. test_dataset (Dataset): Test dataset.
  57. callbacks (list): A list of callbacks.
  58. """
  59. callbacks = callbacks if callbacks else []
  60. if pre_training_dataset is not None:
  61. print(" | Start pre-training job.")
  62. if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1:
  63. print(f" | Rank {MultiAscend.get_rank()} Call model train.")
  64. model.train(config.epochs, pre_training_dataset,
  65. callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode,
  66. sink_size=config.dataset_sink_step)
  67. # Test the accuracy of the model.
  68. if test_dataset is not None:
  69. print(" | Start test job.")
  70. result = infer(_config)
  71. with open("validation_res_after_pre_training.bin", "wb") as f:
  72. pickle.dump(result, f, 1)
  73. if fine_tune_dataset is not None:
  74. print(" | Start fine-tuning job.")
  75. model.train(config.epochs, fine_tune_dataset,
  76. callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode,
  77. sink_size=config.dataset_sink_step)
  78. # Test the accuracy of the model.
  79. if test_dataset is not None:
  80. print(" | Start test job.")
  81. result = infer(_config)
  82. with open("validation_res_after_pre_training.bin", "wb") as f:
  83. pickle.dump(result, f, 1)
  84. def _build_training_pipeline(config: TransformerConfig,
  85. pre_training_dataset=None,
  86. fine_tune_dataset=None,
  87. test_dataset=None,
  88. platform="Ascend"):
  89. """
  90. Build training pipeline.
  91. Args:
  92. config (TransformerConfig): Config of mass model.
  93. pre_training_dataset (Dataset): Pre-training dataset.
  94. fine_tune_dataset (Dataset): Fine-tune dataset.
  95. test_dataset (Dataset): Test dataset.
  96. """
  97. net_with_loss = TransformerNetworkWithLoss(config, is_training=True)
  98. net_with_loss.init_parameters_data()
  99. if config.existed_ckpt:
  100. if config.existed_ckpt.endswith(".npz"):
  101. weights = np.load(config.existed_ckpt)
  102. else:
  103. weights = load_checkpoint(config.existed_ckpt)
  104. for param in net_with_loss.trainable_params():
  105. weights_name = param.name
  106. if weights_name not in weights:
  107. raise ValueError(f"Param {weights_name} is not found in ckpt file.")
  108. if isinstance(weights[weights_name], Parameter):
  109. param.default_input = weights[weights_name].default_input
  110. elif isinstance(weights[weights_name], Tensor):
  111. param.default_input = Tensor(weights[weights_name].asnumpy(), config.dtype)
  112. elif isinstance(weights[weights_name], np.ndarray):
  113. param.default_input = Tensor(weights[weights_name], config.dtype)
  114. else:
  115. param.default_input = weights[weights_name]
  116. else:
  117. for param in net_with_loss.trainable_params():
  118. name = param.name
  119. value = param.default_input
  120. if isinstance(value, Tensor):
  121. if name.endswith(".gamma"):
  122. param.default_input = one_weight(value.asnumpy().shape)
  123. elif name.endswith(".beta") or name.endswith(".bias"):
  124. param.default_input = zero_weight(value.asnumpy().shape)
  125. else:
  126. param.default_input = weight_variable(value.asnumpy().shape)
  127. dataset = pre_training_dataset if pre_training_dataset is not None \
  128. else fine_tune_dataset
  129. if dataset is None:
  130. raise ValueError("pre-training dataset or fine-tuning dataset must be provided one.")
  131. update_steps = dataset.get_repeat_count() * dataset.get_dataset_size()
  132. if config.lr_scheduler == "isr":
  133. lr = Tensor(square_root_schedule(lr=config.lr,
  134. update_num=update_steps,
  135. decay_start_step=config.decay_start_step,
  136. warmup_steps=config.warmup_steps,
  137. min_lr=config.min_lr), dtype=mstype.float32)
  138. elif config.lr_scheduler == "poly":
  139. lr = Tensor(polynomial_decay_scheduler(lr=config.lr,
  140. min_lr=config.min_lr,
  141. decay_steps=config.decay_steps,
  142. total_update_num=update_steps,
  143. warmup_steps=config.warmup_steps,
  144. power=config.poly_lr_scheduler_power), dtype=mstype.float32)
  145. else:
  146. lr = config.lr
  147. if config.optimizer.lower() == "adam":
  148. optimizer = Adam(net_with_loss.trainable_params(), lr, beta1=0.9, beta2=0.98)
  149. elif config.optimizer.lower() == "lamb":
  150. lr = BertLearningRate(decay_steps=12000, learning_rate=config.lr, end_learning_rate=config.min_lr,
  151. power=10.0, warmup_steps=config.warmup_steps)
  152. decay_params = list(filter(lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
  153. net_with_loss.trainable_params()))
  154. other_params = list(filter(lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower(),
  155. net_with_loss.trainable_params()))
  156. group_params = [{'params': decay_params, 'weight_decay': 0.01},
  157. {'params': other_params}]
  158. optimizer = Lamb(group_params, lr, eps=1e-6)
  159. elif config.optimizer.lower() == "momentum":
  160. optimizer = Momentum(net_with_loss.trainable_params(), lr, momentum=0.9)
  161. else:
  162. raise ValueError(f"optimizer only support `adam` and `momentum` now.")
  163. # loss scale.
  164. if config.loss_scale_mode == "dynamic":
  165. scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale,
  166. scale_factor=config.loss_scale_factor,
  167. scale_window=config.scale_window)
  168. else:
  169. scale_manager = FixedLossScaleManager(loss_scale=config.init_loss_scale, drop_overflow_update=True)
  170. net_with_grads = TransformerTrainOneStepWithLossScaleCell(network=net_with_loss, optimizer=optimizer,
  171. scale_update_cell=scale_manager.get_update_cell())
  172. net_with_grads.set_train(True)
  173. model = Model(net_with_grads)
  174. loss_monitor = LossCallBack(config)
  175. ckpt_config = CheckpointConfig(save_checkpoint_steps=config.save_ckpt_steps,
  176. keep_checkpoint_max=config.keep_ckpt_max)
  177. rank_size = os.getenv('RANK_SIZE')
  178. callbacks = [loss_monitor]
  179. if rank_size is not None and int(rank_size) > 1 and MultiAscend.get_rank() % 8 == 0:
  180. ckpt_callback = ModelCheckpoint(
  181. prefix=config.ckpt_prefix,
  182. directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
  183. config=ckpt_config)
  184. callbacks.append(ckpt_callback)
  185. if rank_size is None or int(rank_size) == 1:
  186. ckpt_callback = ModelCheckpoint(
  187. prefix=config.ckpt_prefix,
  188. directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
  189. config=ckpt_config)
  190. callbacks.append(ckpt_callback)
  191. print(f" | ALL SET, PREPARE TO TRAIN.")
  192. _train(model=model, config=config,
  193. pre_training_dataset=pre_training_dataset,
  194. fine_tune_dataset=fine_tune_dataset,
  195. test_dataset=test_dataset,
  196. callbacks=callbacks)
  197. def _setup_parallel_env(platform):
  198. context.reset_auto_parallel_context()
  199. MultiAscend.init()
  200. context.set_auto_parallel_context(
  201. parallel_mode=ParallelMode.DATA_PARALLEL,
  202. device_num=MultiAscend.get_group_size(),
  203. parameter_broadcast=True,
  204. mirror_mean=True
  205. )
  206. def train_parallel(config: TransformerConfig, platform: "Ascend"):
  207. """
  208. Train model with multi ascend chips.
  209. Args:
  210. config (TransformerConfig): Config for MASS model.
  211. """
  212. _setup_parallel_env(platform)
  213. print(f" | Starting training on {os.getenv('RANK_SIZE', None)} devices.")
  214. pre_train_dataset = load_dataset(
  215. data_files=config.pre_train_dataset,
  216. batch_size=config.batch_size, epoch_count=1,
  217. sink_mode=config.dataset_sink_mode,
  218. sink_step=config.dataset_sink_step,
  219. rank_size=MultiAscend.get_group_size(),
  220. rank_id=MultiAscend.get_rank()
  221. ) if config.pre_train_dataset else None
  222. fine_tune_dataset = load_dataset(
  223. data_files=config.fine_tune_dataset,
  224. batch_size=config.batch_size, epoch_count=1,
  225. sink_mode=config.dataset_sink_mode,
  226. sink_step=config.dataset_sink_step,
  227. rank_size=MultiAscend.get_group_size(),
  228. rank_id=MultiAscend.get_rank()
  229. ) if config.fine_tune_dataset else None
  230. test_dataset = load_dataset(
  231. data_files=config.test_dataset,
  232. batch_size=config.batch_size, epoch_count=1,
  233. sink_mode=config.dataset_sink_mode,
  234. sink_step=config.dataset_sink_step,
  235. rank_size=MultiAscend.get_group_size(),
  236. rank_id=MultiAscend.get_rank()
  237. ) if config.test_dataset else None
  238. _build_training_pipeline(config=config,
  239. pre_training_dataset=pre_train_dataset,
  240. fine_tune_dataset=fine_tune_dataset,
  241. test_dataset=test_dataset,
  242. platform=platform)
  243. def train_single(config: TransformerConfig, platform: "Ascend"):
  244. """
  245. Train model on single device.
  246. Args:
  247. config (TransformerConfig): Config for model.
  248. """
  249. print(" | Starting training on single device.")
  250. pre_train_dataset = load_dataset(data_files=config.pre_train_dataset,
  251. batch_size=config.batch_size,
  252. epoch_count=1,
  253. sink_mode=config.dataset_sink_mode,
  254. sink_step=config.dataset_sink_step) if config.pre_train_dataset else None
  255. fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset,
  256. batch_size=config.batch_size,
  257. epoch_count=1,
  258. sink_mode=config.dataset_sink_mode,
  259. sink_step=config.dataset_sink_step) if config.fine_tune_dataset else None
  260. test_dataset = load_dataset(data_files=config.test_dataset,
  261. batch_size=config.batch_size,
  262. epoch_count=1,
  263. sink_mode=config.dataset_sink_mode,
  264. sink_step=config.dataset_sink_step) if config.test_dataset else None
  265. _build_training_pipeline(config=config,
  266. pre_training_dataset=pre_train_dataset,
  267. fine_tune_dataset=fine_tune_dataset,
  268. test_dataset=test_dataset,
  269. platform=platform)
  270. def _check_args(config):
  271. if not os.path.exists(config):
  272. raise FileNotFoundError("`config` is not existed.")
  273. if not isinstance(config, str):
  274. raise ValueError("`config` must be type of str.")
  275. if __name__ == '__main__':
  276. args, _ = parser.parse_known_args()
  277. device_id = os.getenv('DEVICE_ID', None)
  278. if device_id is None:
  279. device_id = 0
  280. device_id = int(device_id)
  281. context.set_context(
  282. mode=context.GRAPH_MODE,
  283. device_target=args.platform,
  284. reserve_class_name_in_scope=False,
  285. device_id=device_id)
  286. _rank_size = os.getenv('RANK_SIZE')
  287. _check_args(args.config)
  288. _config = get_config(args.config)
  289. np.random.seed(_config.random_seed)
  290. context.set_context(save_graphs=_config.save_graphs)
  291. if _rank_size is not None and int(_rank_size) > 1:
  292. train_parallel(_config, args.platform)
  293. else:
  294. train_single(_config, args.platform)