Merge pull request !4412 from linqingke/masstags/v0.7.0-beta
@@ -57,9 +57,6 @@ The overall network architecture of MASS is shown below, which is Transformer(Va
MASS consists of a 6-layer encoder and a 6-layer decoder with a 1024 embedding/hidden size and a 4096 intermediate size in the feed-forward network, which has two fully connected layers.

# Dataset
Dataset used:
@@ -124,7 +121,8 @@ MASS script and code structure are as follow:
│   ├──all.bpe.codes      // BPE codes table (this file should be generated by the user).
│   ├──all_en.dict.bin    // Learned vocabulary file (this file should be generated by the user).
├── scripts
│   ├──run.sh             // Train & evaluate model script.
│   ├──run_ascend.sh      // Ascend train & evaluate model script.
│   ├──run_gpu.sh         // GPU train & evaluate model script.
│   ├──learn_subword.sh   // Learn BPE codes.
│   ├──stop_training.sh   // Stop training.
├── requirements.txt      // Requirements of third-party packages.
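Since `all.bpe.codes` and the vocabulary must be generated by the user, here is a minimal sketch of how such files are commonly produced. It assumes the `subword-nmt` toolkit and a plain-text corpus named `all_corpus.txt` (both placeholders); the exact commands wrapped by `learn_subword.sh` may differ.

```bash
# Illustrative only: produce BPE codes and a token vocabulary with subword-nmt.
# all_corpus.txt stands in for your concatenated training text.
pip install subword-nmt
subword-nmt learn-bpe -s 32000 < all_corpus.txt > all.bpe.codes
subword-nmt apply-bpe -c all.bpe.codes < all_corpus.txt | subword-nmt get-vocab > all_en.vocab.txt
```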
@@ -329,18 +327,24 @@ Almost all of the options and arguments needed could be assigned conveniently, i
For more detailed information about the attributes, refer to the file `config/config.py`.
## Training & Evaluation process
For training a model, the shell script `run.sh` is all you need. In this scripts, the environment variable is set and the training script `train.py` under `mass` is executed.
For training a model, the shell script `run_ascend.sh` or `run_gpu.sh` is all you need. In these scripts, the environment variables are set and the training script `train.py` under `mass` is executed.
You may start a training task on a single device or on multiple devices by assigning the options and running the command in bash:
```bash
sh run.sh [--options]
Ascend:
```bash
sh run_ascend.sh [--options]
```
GPU:
```bash
sh run_gpu.sh [--options]
```
The usage is shown as bellow:
The usage of `run_ascend.sh` is shown as below:
```text
Usage: run.sh [-h, --help] [-t, --task <CHAR>] [-n, --device_num <N>]
              [-i, --device_id <N>] [-j, --hccl_json <FILE>]
              [-c, --config <FILE>] [-o, --output <FILE>]
              [-v, --vocab <FILE>]
Usage: run_ascend.sh [-h, --help] [-t, --task <CHAR>] [-n, --device_num <N>]
                     [-i, --device_id <N>] [-j, --hccl_json <FILE>]
                     [-c, --config <FILE>] [-o, --output <FILE>]
                     [-v, --vocab <FILE>]
options:
    -h, --help               show usage
@@ -350,20 +354,49 @@ options:
    -j, --hccl_json          rank table file used for training with multiple devices: FILE.
    -c, --config             configuration file as shown in the path 'mass/config': FILE.
    -o, --output             assign output file of inference: FILE.
    -v, --vocab              set the vocabulary"
    -v, --vocab              set the vocabulary.
    -m, --metric             set the metric.
```
Note: be sure to assign the `hccl_json` file when running distributed training.
The usage of `run_gpu.sh` is shown as below:
```text
Usage: run_gpu.sh [-h, --help] [-t, --task <CHAR>] [-n, --device_num <N>]
                  [-i, --device_id <N>] [-c, --config <FILE>]
                  [-o, --output <FILE>] [-v, --vocab <FILE>]
options:
    -h, --help               show usage
    -t, --task               select task: CHAR, 't' for train and 'i' for inference.
    -n, --device_num         device number used for training: N, default is 1.
    -i, --device_id          device id used for training with single device: N, 0<=N<=7, default is 0.
    -c, --config             configuration file as shown in the path 'mass/config': FILE.
    -o, --output             assign output file of inference: FILE.
    -v, --vocab              set the vocabulary.
    -m, --metric             set the metric.
```
The following command shows an example of training with 2 devices.
```bash
sh run.sh --task t --device_num 2 --hccl_json /{path}/rank_table.json --config /{path}/config.json
Ascend:
```bash
sh run_ascend.sh --task t --device_num 2 --hccl_json /{path}/rank_table.json --config /{path}/config.json
```
ps. Discontinuous device id is not supported in `run.sh` at present, device id in `rank_table.json` must start from 0.
Note: discontinuous device IDs are not supported in `run_ascend.sh` at present; the device IDs in `rank_table.json` must start from 0.
GPU:
```bash
sh run_gpu.sh --task t --device_num 2 --config /{path}/config.json
```
If you use a single chip, it would be like this:
```bash
sh run.sh --task t --device_num 1 --device_id 0 --config /{path}/config.json
Ascend:
```bash
sh run_ascend.sh --task t --device_num 1 --device_id 0 --config /{path}/config.json
```
GPU:
```bash
sh run_gpu.sh --task t --device_num 1 --device_id 0 --config /{path}/config.json
```
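Both wrapper scripts copy `train.py`, `eval.py` and the configuration into a working directory named after the task and the device id, then run in the background with output redirected to `log.log` (or `log_infer.log` for inference), so progress can be followed with `tail`:

```bash
# Follow the background training/inference logs produced by the wrapper scripts.
tail -f train_mass_0/log.log        # training on DEVICE_ID=0
tail -f infer_mass_0/log_infer.log  # inference on DEVICE_ID=0
```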
@@ -441,9 +474,6 @@ During testing, we use the fine-tuned model to predict the result, and adopt a
get the most possible prediction results.

## Performance
### Results
@@ -543,11 +573,18 @@ For pre-training a model, config the options in `config.json` firstly:
- Set other arguments including dataset configurations and network configurations.
- If you have a trained model already, assign the `existed_ckpt` to the checkpoint file.
Run the shell script `run.sh` as followed:
If you use an Ascend chip, run the shell script `run_ascend.sh` as follows:
```bash
sh run.sh -t t -n 1 -i 1 -c /mass/config/config.json
```bash
sh run_ascend.sh -t t -n 1 -i 1 -c /mass/config/config.json
```
You can also run the shell script `run_gpu.sh` on GPU as follows:
```bash
sh run_gpu.sh -t t -n 1 -i 1 -c /mass/config/config.json
```
Get the log and output files under the path `./train_mass_*/`, and the model file under the path assigned in the `config/config.json` file.
## Fine-tuning
@@ -558,10 +595,18 @@ For fine-tuning a model, config the options in `config.json` firstly:
- Assign the `ckpt_prefix` and `ckpt_path` under `checkpoint_path` node to save the model files.
- Set other arguments including dataset configurations and network configurations.
Run the shell script `run.sh` as followed:
```bash
sh run.sh -t t -n 1 -i 1 -c config/config.json
If you use an Ascend chip, run the shell script `run_ascend.sh` as follows:
```bash
sh run_ascend.sh -t t -n 1 -i 1 -c config/config.json
```
You can also run the shell script `run_gpu.sh` on GPU as follows:
```bash
sh run_gpu.sh -t t -n 1 -i 1 -c config/config.json
```
Get the log and output files under the path `./train_mass_*/`, and the model file under the path assigned in the `config/config.json` file.
## Inference
@@ -573,10 +618,16 @@ For inference, config the options in `config.json` firstly:
- Assign the `ckpt_prefix` and `ckpt_path` under `checkpoint_path` node to save the model files.
- Set other arguments including dataset configurations and network configurations.
Run the shell script `run.sh` as followed:
If you use an Ascend chip, run the shell script `run_ascend.sh` as follows:
```bash
sh run.sh -t i -n 1 -i 1 -c config/config.json -o {outputfile}
sh run_ascend.sh -t i -n 1 -i 1 -c config/config.json -o {outputfile}
```
You can also run the shell script `run_gpu.sh` on GPU as follows:
```bash
sh run_gpu.sh -t i -n 1 -i 1 -c config/config.json -o {outputfile}
```
# Description of random situation
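The `-m/--metric` option is forwarded to `eval.py`, where `rouge` is the default and any other value falls through to the perplexity path (`infer_ppl`). As a sketch (the exact string, e.g. `ppl`, is only illustrative, and the file names are placeholders), a perplexity evaluation on GPU could look like:

```bash
# Switch the evaluation metric from the default ROUGE to perplexity.
# 'ppl' is an illustrative value; eval.py only checks whether the metric equals 'rouge'.
sh run_gpu.sh -t i -n 1 -i 0 -c config/config.json -o {outputfile} -v {vocabfile} -m ppl
```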
@@ -13,10 +13,12 @@
# limitations under the License.
# ============================================================================
"""Evaluation api."""
import os
import argparse
import pickle

from mindspore.common import dtype as mstype
from mindspore import context

from config import TransformerConfig
from src.transformer import infer, infer_ppl
@@ -32,6 +34,8 @@ parser.add_argument("--output", type=str, required=True,
                    help="Result file path.")
parser.add_argument("--metric", type=str, default='rouge',
                    help='Set eval method.')
parser.add_argument("--platform", type=str, required=True,
                    help="model working platform.")
def get_config(config):
@@ -46,6 +50,16 @@ if __name__ == '__main__':
    vocab = Dictionary.load_from_persisted_dict(args.vocab)
    _config = get_config(args.config)

    device_id = os.getenv('DEVICE_ID', None)
    if device_id is None:
        device_id = 0
    device_id = int(device_id)
    context.set_context(
        mode=context.GRAPH_MODE,
        device_target=args.platform,
        reserve_class_name_in_scope=False,
        device_id=device_id)

    if args.metric == 'rouge':
        result = infer(_config)
    else:
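With the new `--platform` argument and the `DEVICE_ID` fallback above, `eval.py` can also be called directly instead of through the wrapper scripts. A rough sketch on GPU (file names are placeholders):

```bash
# Hypothetical direct invocation of the evaluation entry point on GPU.
# The wrapper scripts normally copy these files into a working directory first.
export DEVICE_ID=0
python eval.py --config config.json \
               --vocab all_en.dict.bin \
               --output output.txt \
               --metric rouge \
               --platform GPU >> log_infer.log 2>&1
```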
@@ -165,10 +165,10 @@ do
    echo $task
    if [ "$task" == "train" ]
    then
        python train.py --config ${configurations##*/} >>log.log 2>&1 &
        python train.py --config ${configurations##*/} --platform Ascend >>log.log 2>&1 &
    elif [ "$task" == "infer" ]
    then
        python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} --metric ${metric} >>log_infer.log 2>&1 &
        python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} --metric ${metric} --platform Ascend >>log_infer.log 2>&1 &
    fi
    cd ../
done
@@ -0,0 +1,157 @@
#!/usr/bin/env bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

options=`getopt -u -o ht:n:i:c:o:v:m: -l help,task:,device_num:,device_id:,config:,output:,vocab:,metric: -- "$@"`
eval set -- "$options"
echo $options

echo_help()
{
    echo "Usage:"
    echo "bash run_gpu.sh [-h] [-t t|i] [-n N] [-i N] [-c FILE] [-o FILE] [-v FILE] [-m STR]"
    echo "options:"
    echo "        -h --help                show usage"
    echo "        -t --task                select task, 't' for training and 'i' for inference"
    echo "        -n --device_num          training with N devices"
    echo "        -i --device_id           training with device i"
    echo "        -c --config              set the configuration file"
    echo "        -o --output              set the output file of inference"
    echo "        -v --vocab               set the vocabulary"
    echo "        -m --metric              set the metric"
}

set_device_id()
{
    while [ -n "$1" ]
    do
        if [[ "$1" == "-i" || "$1" == "--device_id" ]]
        then
            if [[ $2 -ge 0 && $2 -le 7 ]]
            then
                export DEVICE_ID=$2
            fi
            break
        fi
        shift
    done
}

while [ -n "$1" ]
do
    case "$1" in
    -h|--help)
        echo_help
        shift
        ;;
    -t|--task)
        echo "task:"
        if [ "$2" == "t" ]
        then
            task=train
        elif [ "$2" == "i" ]
        then
            task=infer
        fi
        shift 2
        ;;
    -n|--device_num)
        echo "device_num"
        if [ $2 -eq 1 ]
        then
            set_device_id $options
        elif [ $2 -gt 1 ]
        then
            export RANK_SIZE=$2
        fi
        shift 2
        ;;
    -i|--device_id)
        echo "set device id"
        export DEVICE_ID=$2
        shift 2
        ;;
    -c|--config)
        echo "config";
        configurations=$2
        shift 2
        ;;
    -o|--output)
        echo "output";
        output=$2
        shift 2
        ;;
    -v|--vocab)
        echo "vocab";
        vocab=$2
        shift 2
        ;;
    -m|--metric)
        echo "metric";
        metric=$2
        shift 2
        ;;
    --)
        shift
        break
        ;;
    *)
        shift
        ;;
    esac
done

file_path=$(cd "$(dirname $0)" || exit; pwd)
if [ $RANK_SIZE -gt 1 ]
then
    echo "Working on $RANK_SIZE devices"
fi

echo "Working on file ${task}_mass_$DEVICE_ID"
cd $file_path || exit
cd ../ || exit

rm -rf ./${task}_mass_$DEVICE_ID
mkdir ./${task}_mass_$DEVICE_ID

cp train.py ./${task}_mass_$DEVICE_ID
cp eval.py ./${task}_mass_$DEVICE_ID
cp $configurations ./${task}_mass_$DEVICE_ID

if [ $vocab ]
then
    cp $vocab ./${task}_mass_$DEVICE_ID
fi

cd ./${task}_mass_$DEVICE_ID || exit
env > log.log
echo $task
if [ "$task" == "train" ]
then
    if [ $RANK_SIZE -gt 1 ]
    then
        mpirun -n $RANK_SIZE python train.py --config ${configurations##*/} --platform GPU >>log.log 2>&1 &
    else
        python train.py --config ${configurations##*/} --platform GPU >>log.log 2>&1 &
    fi
elif [ "$task" == "infer" ]
then
    python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} --metric ${metric} --platform GPU >>log_infer.log 2>&1 &
fi
cd ../
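Like the Ascend wrapper, `run_gpu.sh` launches `train.py`/`eval.py` in the background. To stop a run you can use the provided helper script (its exact arguments are not shown in this diff) or kill the processes directly, for example:

```bash
# Stop a background MASS run; stop_training.sh is the helper shipped in scripts/,
# and pkill is a generic fallback.
sh scripts/stop_training.sh
pkill -f "train.py --config"
```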
@@ -14,6 +14,7 @@
# ============================================================================
"""Transformer for training."""
from mindspore import nn
import mindspore.context as context
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
@@ -204,11 +205,16 @@ class TransformerNetworkWithLoss(nn.Cell):
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * F.cast(reciprocal(scale), F.dtype(grad))


_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()


@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
    return grad_overflow(grad)
class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
    """
@@ -251,9 +257,16 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.clip_gradients = ClipGradients()
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        if context.get_context("device_target") == "GPU":
            self.gpu_target = True
            self.float_status = P.FloatStatus()
            self.addn = P.AddN()
            self.reshape = P.Reshape()
        else:
            self.gpu_target = False
            self.alloc_status = P.NPUAllocFloatStatus()
            self.get_status = P.NPUGetFloatStatus()
            self.clear_status = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
@@ -304,14 +317,18 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
                            target_mask,
                            label_ids,
                            label_weights)
        # Alloc status.
        init = self.alloc_status()
        # Clear overflow buffer.
        self.clear_before_grad(init)
        init = False
        if not self.gpu_target:
            # init overflow buffer
            init = self.alloc_status()
            # clear overflow buffer
            self.clear_status(init)

        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        grads = self.grad(self.network, weights)(source_ids,
                                                 source_mask,
                                                 target_ids,
@@ -323,11 +340,21 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
        grads = self.clip_gradients(grads, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
        if self.reducer_flag:
            # Apply grad reducer on grads.
            grads = self.grad_reducer(grads)

        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        # get the overflow buffer
        if not self.gpu_target:
            self.get_status(init)
            # sum overflow buffer elements, 0: not overflow, >0: overflow
            flag_sum = self.reduce_sum(init, (0,))
        else:
            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
            flag_sum = self.addn(flag_sum)
            # convert flag_sum to scalar
            flag_sum = self.reshape(flag_sum, (()))

        if self.is_distributed:
            # Sum overflow flag over devices.
@@ -49,11 +49,13 @@ class LossCallBack(Callback):
        file_name = "./loss.log"
        with open(file_name, "a+") as f:
            time_stamp_current = self._get_ms_timestamp()
            f.write("time: {}, epoch: {}, step: {}, outputs are {}.\n".format(
            f.write("time: {}, epoch: {}, step: {}, outputs are {},{},{}.\n".format(
                time_stamp_current - self.time_stamp_first,
                cb_params.cur_epoch_num,
                cb_params.cur_step_num,
                str(cb_params.net_outputs)
                str(cb_params.net_outputs[0].asnumpy()),
                str(cb_params.net_outputs[1].asnumpy()),
                str(cb_params.net_outputs[2].asnumpy())
            ))

    @staticmethod
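With this change each record in `loss.log` carries three values after `outputs are`; judging from the training cell above, these are most likely the loss, the overflow flag and the current loss scale. The log can be inspected directly, for example:

```bash
# Show the most recent loss records written by LossCallBack
# (the working directory is created by the run scripts, e.g. train_mass_0).
tail -n 5 train_mass_0/loss.log
```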
@@ -24,7 +24,7 @@ from mindspore.common.tensor import Tensor
from mindspore.nn import Momentum
from mindspore.nn.optim import Adam, Lamb
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore import context, ParallelMode, Parameter
from mindspore.communication import management as MultiAscend
@@ -41,18 +41,7 @@ from src.utils.lr_scheduler import polynomial_decay_scheduler, BertLearningRate
parser = argparse.ArgumentParser(description='MASS train entry point.')
parser.add_argument("--config", type=str, required=True, help="model config json file path.")

device_id = os.getenv('DEVICE_ID', None)
if device_id is None:
    raise RuntimeError("`DEVICE_ID` can not be None.")
device_id = int(device_id)
context.set_context(
    mode=context.GRAPH_MODE,
    device_target="Ascend",
    reserve_class_name_in_scope=False,
    device_id=device_id)

parser.add_argument("--platform", type=str, required=True, help="model working platform.")


def get_config(config):
    config = TransformerConfig.from_json_file(config)
@@ -79,12 +68,11 @@ def _train(model, config: TransformerConfig,
    if pre_training_dataset is not None:
        print(" | Start pre-training job.")
        epoch_size = config.epochs * pre_training_dataset.get_dataset_size() // config.dataset_sink_step

        if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1:
            print(f" | Rank {MultiAscend.get_rank()} Call model train.")

        model.train(epoch_size, pre_training_dataset,
        model.train(config.epochs, pre_training_dataset,
                    callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode,
                    sink_size=config.dataset_sink_step)
@@ -97,9 +85,8 @@ def _train(model, config: TransformerConfig,
    if fine_tune_dataset is not None:
        print(" | Start fine-tuning job.")
        epoch_size = config.epochs * fine_tune_dataset.get_dataset_size() // config.dataset_sink_step

        model.train(epoch_size, fine_tune_dataset,
        model.train(config.epochs, fine_tune_dataset,
                    callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode,
                    sink_size=config.dataset_sink_step)
@@ -114,7 +101,8 @@ def _train(model, config: TransformerConfig,
def _build_training_pipeline(config: TransformerConfig,
                             pre_training_dataset=None,
                             fine_tune_dataset=None,
                             test_dataset=None):
                             test_dataset=None,
                             platform="Ascend"):
    """
    Build training pipeline.
@@ -198,14 +186,15 @@ def _build_training_pipeline(config: TransformerConfig,
    else:
        raise ValueError(f"optimizer only support `adam` and `momentum` now.")

    # Dynamic loss scale.
    scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale,
                                            scale_factor=config.loss_scale_factor,
                                            scale_window=config.scale_window)
    net_with_grads = TransformerTrainOneStepWithLossScaleCell(
        network=net_with_loss, optimizer=optimizer,
        scale_update_cell=scale_manager.get_update_cell()
    )
    # Loss scale.
    if platform == "Ascend":
        scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale,
                                                scale_factor=config.loss_scale_factor,
                                                scale_window=config.scale_window)
    else:
        scale_manager = FixedLossScaleManager(loss_scale=1.0, drop_overflow_update=True)
    net_with_grads = TransformerTrainOneStepWithLossScaleCell(network=net_with_loss, optimizer=optimizer,
                                                              scale_update_cell=scale_manager.get_update_cell())
    net_with_grads.set_train(True)
    model = Model(net_with_grads)
    loss_monitor = LossCallBack(config)
@@ -236,9 +225,12 @@ def _build_training_pipeline(config: TransformerConfig,
                    callbacks=callbacks)


def _setup_parallel_env():
def _setup_parallel_env(platform):
    context.reset_auto_parallel_context()
    MultiAscend.init()
    if platform == "GPU":
        MultiAscend.init("nccl")
    else:
        MultiAscend.init()
    context.set_auto_parallel_context(
        parallel_mode=ParallelMode.DATA_PARALLEL,
        device_num=MultiAscend.get_group_size(),
@@ -247,14 +239,14 @@ )
    )


def train_parallel(config: TransformerConfig):
def train_parallel(config: TransformerConfig, platform="Ascend"):
| """ | |||
| Train model with multi ascend chips. | |||
| Args: | |||
| config (TransformerConfig): Config for MASS model. | |||
| """ | |||
| _setup_parallel_env() | |||
| _setup_parallel_env(platform) | |||
| print(f" | Starting training on {os.getenv('RANK_SIZE', None)} devices.") | |||
@@ -286,10 +278,11 @@ def train_parallel(config: TransformerConfig):
    _build_training_pipeline(config=config,
                             pre_training_dataset=pre_train_dataset,
                             fine_tune_dataset=fine_tune_dataset,
                             test_dataset=test_dataset)
                             test_dataset=test_dataset,
                             platform=platform)


def train_single(config: TransformerConfig):
def train_single(config: TransformerConfig, platform="Ascend"):
| """ | |||
| Train model on single device. | |||
| @@ -316,7 +309,8 @@ def train_single(config: TransformerConfig): | |||
| _build_training_pipeline(config=config, | |||
| pre_training_dataset=pre_train_dataset, | |||
| fine_tune_dataset=fine_tune_dataset, | |||
| test_dataset=test_dataset) | |||
| test_dataset=test_dataset, | |||
| platform=platform) | |||
def _check_args(config):
@@ -327,9 +321,20 @@ def _check_args(config):
if __name__ == '__main__':
    args, _ = parser.parse_known_args()

    device_id = os.getenv('DEVICE_ID', None)
    if device_id is None:
        device_id = 0
    device_id = int(device_id)
    context.set_context(
        mode=context.GRAPH_MODE,
        device_target=args.platform,
        reserve_class_name_in_scope=False,
        device_id=device_id)

    _rank_size = os.getenv('RANK_SIZE')

    args, _ = parser.parse_known_args()
    _check_args(args.config)
    _config = get_config(args.config)
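Because `train.py` now takes the target platform from `--platform` and falls back to `DEVICE_ID=0` when the environment variable is unset, it can also be launched directly for a quick single-device run, roughly like this:

```bash
# Direct single-device run outside the wrapper scripts; DEVICE_ID defaults to 0.
export DEVICE_ID=0
python train.py --config config/config.json --platform GPU >> log.log 2>&1
```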
@@ -337,6 +342,6 @@ if __name__ == '__main__':
    context.set_context(save_graphs=_config.save_graphs)
    if _rank_size is not None and int(_rank_size) > 1:
        train_parallel(_config)
        train_parallel(_config, args.platform)
    else:
        train_single(_config)
        train_single(_config, args.platform)
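As the snippet above shows, `train.py` picks `train_parallel` only when `RANK_SIZE` is greater than 1, which is exactly what `run_gpu.sh` arranges through `mpirun`. A minimal sketch of a 2-GPU launch, assuming OpenMPI and NCCL are installed:

```bash
# Two-process data-parallel GPU training; each process initializes NCCL
# through _setup_parallel_env("GPU").
export RANK_SIZE=2
mpirun -n 2 python train.py --config config/config.json --platform GPU
```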