
train.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train api."""
import os
import argparse

import numpy as np

import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.nn import Momentum
from mindspore.nn.optim import Lamb
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, SummaryCollector, TimeMonitor
from mindspore import context, Parameter
from mindspore.context import ParallelMode
from mindspore.communication import management as MultiAscend
from mindspore.train.serialization import load_checkpoint
from mindspore.common import set_seed

from config import GNMTConfig
from src.dataset import load_dataset
from src.gnmt_model import GNMTNetworkWithLoss, GNMTTrainOneStepWithLossScaleCell
from src.utils import LossCallBack
from src.utils import one_weight, weight_variable
from src.utils.lr_scheduler import square_root_schedule, polynomial_decay_scheduler, Warmup_MultiStepLR_scheduler
from src.utils.optimizer import Adam

parser = argparse.ArgumentParser(description='GNMT train entry point.')
parser.add_argument("--config", type=str, required=True, help="model config json file path.")
parser.add_argument("--pre_train_dataset", type=str, required=True, help="pre-train dataset address.")

device_id = os.getenv('DEVICE_ID', None)
if device_id is None:
    raise RuntimeError("`DEVICE_ID` cannot be None.")
device_id = int(device_id)
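# Run in graph mode on the Ascend device selected by the DEVICE_ID environment variable.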
context.set_context(
    mode=context.GRAPH_MODE,
    save_graphs=False,
    device_target="Ascend",
    reserve_class_name_in_scope=True,
    device_id=device_id)


def get_config(config):
    """Load the GNMTConfig from a json file and set the compute dtypes."""
    config = GNMTConfig.from_json_file(config)
    config.compute_type = mstype.float16
    config.dtype = mstype.float32
    return config


def _train(model, config: GNMTConfig,
           pre_training_dataset=None, fine_tune_dataset=None, test_dataset=None,
           callbacks: list = None):
    """
    Train model.

    Args:
        model (Model): MindSpore model instance.
        config (GNMTConfig): Config of the GNMT model.
        pre_training_dataset (Dataset): Pre-training dataset.
        fine_tune_dataset (Dataset): Fine-tune dataset.
        test_dataset (Dataset): Test dataset.
        callbacks (list): A list of callbacks.
    """
    callbacks = callbacks if callbacks else []

    if pre_training_dataset is not None:
        print(" | Start pre-training job.")
        epoch_size = pre_training_dataset.get_repeat_count()
        print("epoch size ", epoch_size)
        if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1:
            print(f" | Rank {MultiAscend.get_rank()} Call model train.")
        model.train(config.epochs, pre_training_dataset,
                    callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode)

    if fine_tune_dataset is not None:
        print(" | Start fine-tuning job.")
        epoch_size = fine_tune_dataset.get_repeat_count()
        model.train(config.epochs, fine_tune_dataset,
                    callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode)


def _load_checkpoint_to_net(config, network):
    """Load parameters into the network from a checkpoint."""
    if config.existed_ckpt:
        if config.existed_ckpt.endswith(".npz"):
            weights = np.load(config.existed_ckpt)
        else:
            weights = load_checkpoint(config.existed_ckpt)
        for param in network.trainable_params():
            weights_name = param.name
            if weights_name not in weights:
                raise ValueError(f"Param {weights_name} is not found in ckpt file.")
            if isinstance(weights[weights_name], Parameter):
                param.set_data(weights[weights_name].data)
            elif isinstance(weights[weights_name], Tensor):
                param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
            elif isinstance(weights[weights_name], np.ndarray):
                param.set_data(Tensor(weights[weights_name], config.dtype))
            else:
                param.set_data(weights[weights_name])
    else:
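        # No checkpoint provided: re-initialize normalization and bias parameters,
        # setting gamma to ones and beta/bias to small random values in the parameter's dtype.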
        for param in network.trainable_params():
            name = param.name
            value = param.data
            if isinstance(value, Tensor):
                if name.endswith(".gamma"):
                    param.set_data(one_weight(value.asnumpy().shape))
                elif name.endswith(".beta") or name.endswith(".bias"):
                    if param.data.dtype == "Float32":
                        param.set_data((weight_variable(value.asnumpy().shape).astype(np.float32)))
                    elif param.data.dtype == "Float16":
                        param.set_data((weight_variable(value.asnumpy().shape).astype(np.float16)))
                else:
                    if param.data.dtype == "Float32":
                        param.set_data(Tensor(weight_variable(value.asnumpy().shape).astype(np.float32)))
                    elif param.data.dtype == "Float16":
                        param.set_data(Tensor(weight_variable(value.asnumpy().shape).astype(np.float16)))


def _get_lr(config, update_steps):
    """Generate learning rate."""
    if config.lr_scheduler == "isr":
        lr = Tensor(square_root_schedule(lr=config.lr,
                                         update_num=update_steps,
                                         decay_start_step=config.decay_start_step,
                                         warmup_steps=config.warmup_steps,
                                         min_lr=config.min_lr), dtype=mstype.float32)
    elif config.lr_scheduler == "poly":
        lr = Tensor(polynomial_decay_scheduler(lr=config.lr,
                                               min_lr=config.min_lr,
                                               decay_steps=config.decay_steps,
                                               total_update_num=update_steps,
                                               warmup_steps=config.warmup_steps,
                                               power=config.lr_scheduler_power), dtype=mstype.float32)
    elif config.lr_scheduler == "WarmupMultiStepLR":
        lr = Tensor(Warmup_MultiStepLR_scheduler(base_lr=config.lr,
                                                 total_update_num=update_steps,
                                                 warmup_steps=config.warmup_steps,
                                                 remain_steps=config.warmup_lr_remain_steps,
                                                 decay_interval=config.warmup_lr_decay_interval,
                                                 decay_steps=config.decay_steps,
                                                 decay_factor=config.lr_scheduler_power), dtype=mstype.float32)
    else:
        lr = config.lr
    return lr


def _get_optimizer(config, network, lr):
    """Get the GNMT optimizer; supports Adam, Lamb and Momentum."""
    if config.optimizer.lower() == "adam":
        optimizer = Adam(network.trainable_params(), lr, beta1=0.9, beta2=0.98)
    elif config.optimizer.lower() == "lamb":
        optimizer = Lamb(network.trainable_params(), learning_rate=lr,
                         eps=1e-6)
    elif config.optimizer.lower() == "momentum":
        optimizer = Momentum(network.trainable_params(), lr, momentum=0.9)
    else:
        raise ValueError("Optimizer only supports `adam`, `lamb` and `momentum` for now.")
    return optimizer


def _build_training_pipeline(config: GNMTConfig,
                             pre_training_dataset=None,
                             fine_tune_dataset=None,
                             test_dataset=None):
    """
    Build training pipeline.

    Args:
        config (GNMTConfig): Config of the GNMT model.
        pre_training_dataset (Dataset): Pre-training dataset.
        fine_tune_dataset (Dataset): Fine-tune dataset.
        test_dataset (Dataset): Test dataset.
    """
    net_with_loss = GNMTNetworkWithLoss(config, is_training=True, use_one_hot_embeddings=True)
    net_with_loss.init_parameters_data()
    _load_checkpoint_to_net(config, net_with_loss)

    dataset = pre_training_dataset if pre_training_dataset is not None \
        else fine_tune_dataset
    if dataset is None:
        raise ValueError("Either a pre-training dataset or a fine-tuning dataset must be provided.")

    update_steps = config.epochs * dataset.get_dataset_size()
    lr = _get_lr(config, update_steps)
    optimizer = _get_optimizer(config, net_with_loss, lr)

    # Dynamic loss scale.
    scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale,
                                            scale_factor=config.loss_scale_factor,
                                            scale_window=config.scale_window)
    net_with_grads = GNMTTrainOneStepWithLossScaleCell(
        network=net_with_loss, optimizer=optimizer,
        scale_update_cell=scale_manager.get_update_cell()
    )
    net_with_grads.set_train(True)
    model = Model(net_with_grads)
    loss_monitor = LossCallBack(config)
    dataset_size = dataset.get_dataset_size()
    time_cb = TimeMonitor(data_size=dataset_size)
    ckpt_config = CheckpointConfig(save_checkpoint_steps=config.save_ckpt_steps,
                                   keep_checkpoint_max=config.keep_ckpt_max)
    rank_size = os.getenv('RANK_SIZE')
    callbacks = [time_cb, loss_monitor]
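    # In distributed runs, write checkpoints and summaries from only one device per
    # 8-device server (rank % 8 == 0) to avoid redundant output.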
    if rank_size is not None and int(rank_size) > 1 and MultiAscend.get_rank() % 8 == 0:
        ckpt_callback = ModelCheckpoint(
            prefix=config.ckpt_prefix,
            directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
            config=ckpt_config)
        callbacks.append(ckpt_callback)
        summary_callback = SummaryCollector(summary_dir="./summary", collect_freq=50)
        callbacks.append(summary_callback)
    if rank_size is None or int(rank_size) == 1:
        ckpt_callback = ModelCheckpoint(
            prefix=config.ckpt_prefix,
            directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
            config=ckpt_config)
        callbacks.append(ckpt_callback)
        summary_callback = SummaryCollector(summary_dir="./summary", collect_freq=50)
        callbacks.append(summary_callback)

    print(" | ALL SET, PREPARE TO TRAIN.")
    _train(model=model, config=config,
           pre_training_dataset=pre_training_dataset,
           fine_tune_dataset=fine_tune_dataset,
           test_dataset=test_dataset,
           callbacks=callbacks)


def _setup_parallel_env():
    context.reset_auto_parallel_context()
    MultiAscend.init()
    context.set_auto_parallel_context(
        parallel_mode=ParallelMode.DATA_PARALLEL,
        device_num=MultiAscend.get_group_size(),
        gradients_mean=True
    )


def train_parallel(config: GNMTConfig):
    """
    Train model with multiple Ascend chips.

    Args:
        config (GNMTConfig): Config for the GNMT model.
    """
    _setup_parallel_env()
    print(f" | Starting training on {os.getenv('RANK_SIZE', None)} devices.")
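    # Each rank loads only its own shard of the data; rank_size and rank_id tell
    # load_dataset how to split the input files across devices.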
    pre_train_dataset = load_dataset(
        data_files=config.pre_train_dataset,
        batch_size=config.batch_size,
        sink_mode=config.dataset_sink_mode,
        rank_size=MultiAscend.get_group_size(),
        rank_id=MultiAscend.get_rank()
    ) if config.pre_train_dataset else None
    fine_tune_dataset = load_dataset(
        data_files=config.fine_tune_dataset,
        batch_size=config.batch_size,
        sink_mode=config.dataset_sink_mode,
        rank_size=MultiAscend.get_group_size(),
        rank_id=MultiAscend.get_rank()
    ) if config.fine_tune_dataset else None
    test_dataset = load_dataset(
        data_files=config.test_dataset,
        batch_size=config.batch_size,
        sink_mode=config.dataset_sink_mode,
        rank_size=MultiAscend.get_group_size(),
        rank_id=MultiAscend.get_rank()
    ) if config.test_dataset else None

    _build_training_pipeline(config=config,
                             pre_training_dataset=pre_train_dataset,
                             fine_tune_dataset=fine_tune_dataset,
                             test_dataset=test_dataset)


def train_single(config: GNMTConfig):
    """
    Train model on a single device.

    Args:
        config (GNMTConfig): Config for the model.
    """
    print(" | Starting training on single device.")
    pre_train_dataset = load_dataset(data_files=config.pre_train_dataset,
                                     batch_size=config.batch_size,
                                     sink_mode=config.dataset_sink_mode) if config.pre_train_dataset else None
    fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset,
                                     batch_size=config.batch_size,
                                     sink_mode=config.dataset_sink_mode) if config.fine_tune_dataset else None
    test_dataset = load_dataset(data_files=config.test_dataset,
                                batch_size=config.batch_size,
                                sink_mode=config.dataset_sink_mode) if config.test_dataset else None
    _build_training_pipeline(config=config,
                             pre_training_dataset=pre_train_dataset,
                             fine_tune_dataset=fine_tune_dataset,
                             test_dataset=test_dataset)


def _check_args(config):
    """Validate the --config argument before loading it."""
    if not isinstance(config, str):
        raise ValueError("`config` must be of type str.")
    if not os.path.exists(config):
        raise FileNotFoundError("`config` does not exist.")


if __name__ == '__main__':
    _rank_size = os.getenv('RANK_SIZE')
    args, _ = parser.parse_known_args()
    _check_args(args.config)
    _config = get_config(args.config)
    _config.pre_train_dataset = args.pre_train_dataset
    set_seed(_config.random_seed)
    if _rank_size is not None and int(_rank_size) > 1:
        train_parallel(_config)
    else:
        train_single(_config)
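
# Example launch (single device), assuming DEVICE_ID is exported and the paths
# below are placeholders for an existing config json and pre-training dataset:
#   DEVICE_ID=0 python train.py --config ./config/config.json --pre_train_dataset ./dataset/train.mindrecord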