- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Train api."""
- import os
- import argparse
- import pickle
-
- import numpy as np
-
- import mindspore.common.dtype as mstype
- from mindspore.common.tensor import Tensor
- from mindspore.nn import Momentum
- from mindspore.nn.optim import Adam, Lamb
- from mindspore.train.model import Model
- from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
- from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
- from mindspore import context, ParallelMode, Parameter
- from mindspore.communication import management as MultiAscend
- from mindspore.train.serialization import load_checkpoint
-
- from config import TransformerConfig
- from src.dataset import load_dataset
- from src.transformer import TransformerNetworkWithLoss, TransformerTrainOneStepWithLossScaleCell
- from src.transformer.infer_mass import infer
- from src.utils import LossCallBack
- from src.utils import one_weight, zero_weight, weight_variable
- from src.utils import square_root_schedule
- from src.utils.lr_scheduler import polynomial_decay_scheduler, BertLearningRate
-
- parser = argparse.ArgumentParser(description='MASS train entry point.')
- parser.add_argument("--config", type=str, required=True, help="model config json file path.")
- parser.add_argument("--platform", type=str, required=True, help="model working platform.")
-
- def get_config(config):
- config = TransformerConfig.from_json_file(config)
- config.compute_type = mstype.float16
- config.dtype = mstype.float32
- return config
-
-
- def _train(model, config: TransformerConfig,
- pre_training_dataset=None, fine_tune_dataset=None, test_dataset=None,
- callbacks: list = None):
- """
- Train model.
-
- Args:
- model (Model): MindSpore model instance.
- config (TransformerConfig): Config of the MASS model.
- pre_training_dataset (Dataset): Pre-training dataset.
- fine_tune_dataset (Dataset): Fine-tune dataset.
- test_dataset (Dataset): Test dataset.
- callbacks (list): A list of callbacks.
- """
- callbacks = callbacks if callbacks else []
-
- if pre_training_dataset is not None:
- print(" | Start pre-training job.")
-
- if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1:
- print(f" | Rank {MultiAscend.get_rank()} Call model train.")
-
- model.train(config.epochs, pre_training_dataset,
- callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode,
- sink_size=config.dataset_sink_step)
-
- # Test the accuracy of the model.
- if test_dataset is not None:
- print(" | Start test job.")
- result = infer(config)
- with open("validation_res_after_pre_training.bin", "wb") as f:
- pickle.dump(result, f, 1)
-
- if fine_tune_dataset is not None:
- print(" | Start fine-tuning job.")
-
- model.train(config.epochs, fine_tune_dataset,
- callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode,
- sink_size=config.dataset_sink_step)
-
- # Test the accuracy of the model.
- if test_dataset is not None:
- print(" | Start test job.")
- result = infer(config)
- with open("validation_res_after_fine_tuning.bin", "wb") as f:
- pickle.dump(result, f, 1)
-
-
- def _build_training_pipeline(config: TransformerConfig,
- pre_training_dataset=None,
- fine_tune_dataset=None,
- test_dataset=None,
- platform="Ascend"):
- """
- Build training pipeline.
-
- Args:
- config (TransformerConfig): Config of the MASS model.
- pre_training_dataset (Dataset): Pre-training dataset.
- fine_tune_dataset (Dataset): Fine-tune dataset.
- test_dataset (Dataset): Test dataset.
- """
- net_with_loss = TransformerNetworkWithLoss(config, is_training=True)
- net_with_loss.init_parameters_data()
-
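- # Either restore weights from an existing checkpoint (.npz or MindSpore .ckpt),
- # or initialize them from scratch: ones for LayerNorm gamma, zeros for beta/bias,
- # and `weight_variable` for all remaining parameters.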
- if config.existed_ckpt:
- if config.existed_ckpt.endswith(".npz"):
- weights = np.load(config.existed_ckpt)
- else:
- weights = load_checkpoint(config.existed_ckpt)
- for param in net_with_loss.trainable_params():
- weights_name = param.name
- if weights_name not in weights:
- raise ValueError(f"Param {weights_name} is not found in ckpt file.")
-
- if isinstance(weights[weights_name], Parameter):
- param.default_input = weights[weights_name].default_input
- elif isinstance(weights[weights_name], Tensor):
- param.default_input = Tensor(weights[weights_name].asnumpy(), config.dtype)
- elif isinstance(weights[weights_name], np.ndarray):
- param.default_input = Tensor(weights[weights_name], config.dtype)
- else:
- param.default_input = weights[weights_name]
- else:
- for param in net_with_loss.trainable_params():
- name = param.name
- value = param.default_input
- if isinstance(value, Tensor):
- if name.endswith(".gamma"):
- param.default_input = one_weight(value.asnumpy().shape)
- elif name.endswith(".beta") or name.endswith(".bias"):
- param.default_input = zero_weight(value.asnumpy().shape)
- else:
- param.default_input = weight_variable(value.asnumpy().shape)
-
- dataset = pre_training_dataset if pre_training_dataset is not None \
- else fine_tune_dataset
-
- if dataset is None:
- raise ValueError("pre-training dataset or fine-tuning dataset must be provided one.")
-
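- # Build the learning-rate schedule: inverse square root ("isr"), polynomial decay
- # ("poly"), or a constant learning rate as the fallback.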
- update_steps = dataset.get_repeat_count() * dataset.get_dataset_size()
- if config.lr_scheduler == "isr":
- lr = Tensor(square_root_schedule(lr=config.lr,
- update_num=update_steps,
- decay_start_step=config.decay_start_step,
- warmup_steps=config.warmup_steps,
- min_lr=config.min_lr), dtype=mstype.float32)
- elif config.lr_scheduler == "poly":
- lr = Tensor(polynomial_decay_scheduler(lr=config.lr,
- min_lr=config.min_lr,
- decay_steps=config.decay_steps,
- total_update_num=update_steps,
- warmup_steps=config.warmup_steps,
- power=config.poly_lr_scheduler_power), dtype=mstype.float32)
- else:
- lr = config.lr
-
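- # Select the optimizer. For Lamb, LayerNorm and bias parameters are placed in a
- # separate parameter group that is excluded from weight decay.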
- if config.optimizer.lower() == "adam":
- optimizer = Adam(net_with_loss.trainable_params(), lr, beta1=0.9, beta2=0.98)
- elif config.optimizer.lower() == "lamb":
- lr = BertLearningRate(decay_steps=12000, learning_rate=config.lr, end_learning_rate=config.min_lr,
- power=10.0, warmup_steps=config.warmup_steps)
- decay_params = list(filter(lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
- net_with_loss.trainable_params()))
- other_params = list(filter(lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower(),
- net_with_loss.trainable_params()))
- group_params = [{'params': decay_params, 'weight_decay': 0.01},
- {'params': other_params}]
-
- optimizer = Lamb(group_params, lr, eps=1e-6)
- elif config.optimizer.lower() == "momentum":
- optimizer = Momentum(net_with_loss.trainable_params(), lr, momentum=0.9)
- else:
- raise ValueError(f"optimizer only support `adam` and `momentum` now.")
-
- # Loss scale: "dynamic" adjusts the scale when overflow occurs; otherwise a fixed
- # scale is used and overflowed updates are dropped.
- if config.loss_scale_mode == "dynamic":
- scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale,
- scale_factor=config.loss_scale_factor,
- scale_window=config.scale_window)
- else:
- scale_manager = FixedLossScaleManager(loss_scale=config.init_loss_scale, drop_overflow_update=True)
- net_with_grads = TransformerTrainOneStepWithLossScaleCell(network=net_with_loss, optimizer=optimizer,
- scale_update_cell=scale_manager.get_update_cell())
- net_with_grads.set_train(True)
- model = Model(net_with_grads)
- loss_monitor = LossCallBack(config)
- ckpt_config = CheckpointConfig(save_checkpoint_steps=config.save_ckpt_steps,
- keep_checkpoint_max=config.keep_ckpt_max)
-
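- # Checkpoints are written on the single device, or, in the distributed case, only on
- # ranks that are multiples of 8 (typically one writer per 8-device server) to avoid
- # redundant checkpoint files.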
- rank_size = os.getenv('RANK_SIZE')
- callbacks = [loss_monitor]
- if rank_size is not None and int(rank_size) > 1 and MultiAscend.get_rank() % 8 == 0:
- ckpt_callback = ModelCheckpoint(
- prefix=config.ckpt_prefix,
- directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
- config=ckpt_config)
- callbacks.append(ckpt_callback)
-
- if rank_size is None or int(rank_size) == 1:
- ckpt_callback = ModelCheckpoint(
- prefix=config.ckpt_prefix,
- directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
- config=ckpt_config)
- callbacks.append(ckpt_callback)
-
- print(f" | ALL SET, PREPARE TO TRAIN.")
- _train(model=model, config=config,
- pre_training_dataset=pre_training_dataset,
- fine_tune_dataset=fine_tune_dataset,
- test_dataset=test_dataset,
- callbacks=callbacks)
-
-
- def _setup_parallel_env(platform):
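- """Initialize communication and configure the data-parallel training context."""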
- context.reset_auto_parallel_context()
- MultiAscend.init()
- context.set_auto_parallel_context(
- parallel_mode=ParallelMode.DATA_PARALLEL,
- device_num=MultiAscend.get_group_size(),
- parameter_broadcast=True,
- mirror_mean=True
- )
-
-
- def train_parallel(config: TransformerConfig, platform: str):
- """
- Train the model on multiple Ascend devices.
-
- Args:
- config (TransformerConfig): Config for MASS model.
- """
- _setup_parallel_env(platform)
-
- print(f" | Starting training on {os.getenv('RANK_SIZE', None)} devices.")
-
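- # rank_size/rank_id are passed so that load_dataset can shard the data, giving each
- # device its own slice of every dataset.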
- pre_train_dataset = load_dataset(
- data_files=config.pre_train_dataset,
- batch_size=config.batch_size, epoch_count=1,
- sink_mode=config.dataset_sink_mode,
- sink_step=config.dataset_sink_step,
- rank_size=MultiAscend.get_group_size(),
- rank_id=MultiAscend.get_rank()
- ) if config.pre_train_dataset else None
- fine_tune_dataset = load_dataset(
- data_files=config.fine_tune_dataset,
- batch_size=config.batch_size, epoch_count=1,
- sink_mode=config.dataset_sink_mode,
- sink_step=config.dataset_sink_step,
- rank_size=MultiAscend.get_group_size(),
- rank_id=MultiAscend.get_rank()
- ) if config.fine_tune_dataset else None
- test_dataset = load_dataset(
- data_files=config.test_dataset,
- batch_size=config.batch_size, epoch_count=1,
- sink_mode=config.dataset_sink_mode,
- sink_step=config.dataset_sink_step,
- rank_size=MultiAscend.get_group_size(),
- rank_id=MultiAscend.get_rank()
- ) if config.test_dataset else None
-
- _build_training_pipeline(config=config,
- pre_training_dataset=pre_train_dataset,
- fine_tune_dataset=fine_tune_dataset,
- test_dataset=test_dataset,
- platform=platform)
-
-
- def train_single(config: TransformerConfig, platform: str):
- """
- Train the model on a single device.
-
- Args:
- config (TransformerConfig): Config for model.
- """
- print(" | Starting training on single device.")
- pre_train_dataset = load_dataset(data_files=config.pre_train_dataset,
- batch_size=config.batch_size,
- epoch_count=1,
- sink_mode=config.dataset_sink_mode,
- sink_step=config.dataset_sink_step) if config.pre_train_dataset else None
- fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset,
- batch_size=config.batch_size,
- epoch_count=1,
- sink_mode=config.dataset_sink_mode,
- sink_step=config.dataset_sink_step) if config.fine_tune_dataset else None
- test_dataset = load_dataset(data_files=config.test_dataset,
- batch_size=config.batch_size,
- epoch_count=1,
- sink_mode=config.dataset_sink_mode,
- sink_step=config.dataset_sink_step) if config.test_dataset else None
-
- _build_training_pipeline(config=config,
- pre_training_dataset=pre_train_dataset,
- fine_tune_dataset=fine_tune_dataset,
- test_dataset=test_dataset,
- platform=platform)
-
-
- def _check_args(config):
- if not isinstance(config, str):
- raise ValueError("`config` must be a str.")
- if not os.path.exists(config):
- raise FileNotFoundError("`config` file does not exist.")
-
-
- if __name__ == '__main__':
- args, _ = parser.parse_known_args()
-
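- # DEVICE_ID is normally exported by the launch scripts; fall back to device 0 when unset.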
- device_id = os.getenv('DEVICE_ID', None)
- if device_id is None:
- device_id = 0
- device_id = int(device_id)
- context.set_context(
- mode=context.GRAPH_MODE,
- device_target=args.platform,
- reserve_class_name_in_scope=False,
- device_id=device_id)
-
- _rank_size = os.getenv('RANK_SIZE')
-
- _check_args(args.config)
- _config = get_config(args.config)
-
- np.random.seed(_config.random_seed)
- context.set_context(save_graphs=_config.save_graphs)
-
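- # RANK_SIZE > 1 indicates a distributed launch; otherwise train on a single device.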
- if _rank_size is not None and int(_rank_size) > 1:
- train_parallel(_config, args.platform)
- else:
- train_single(_config, args.platform)