Merge pull request !652 from wanghua/master
tags/v0.3.0-alpha
@@ -1,57 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py
"""
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import BertConfig

bert_train_cfg = edict({
    'epoch_size': 10,
    'num_warmup_steps': 0,
    'start_learning_rate': 1e-4,
    'end_learning_rate': 0.0,
    'decay_steps': 1000,
    'power': 10.0,
    'save_checkpoint_steps': 2000,
    'keep_checkpoint_max': 10,
    'checkpoint_prefix': "checkpoint_bert",
    # please add your own dataset path
    'DATA_DIR': "/your/path/examples.tfrecord",
    # please add your own dataset schema path
    'SCHEMA_DIR': "/your/path/datasetSchema.json"
})

bert_net_cfg = BertConfig(
    batch_size=16,
    seq_length=128,
    vocab_size=21136,
    hidden_size=1024,
    num_hidden_layers=24,
    num_attention_heads=16,
    intermediate_size=4096,
    hidden_act="gelu",
    hidden_dropout_prob=0.0,
    attention_probs_dropout_prob=0.0,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    use_relative_positions=True,
    input_mask_from_dataset=True,
    token_type_ids_from_dataset=True,
    dtype=mstype.float32,
    compute_type=mstype.float16,
)
@@ -1,96 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
NEZHA (NEural contextualiZed representation for CHinese lAnguage understanding) is a Chinese pretrained
language model developed by Huawei; it is currently based on BERT.

1. Prepare data
Following the same data preparation as BERT, run the command below to generate the dataset for training:
    python ./create_pretraining_data.py \
        --input_file=./sample_text.txt \
        --output_file=./examples.tfrecord \
        --vocab_file=./your/path/vocab.txt \
        --do_lower_case=True \
        --max_seq_length=128 \
        --max_predictions_per_seq=20 \
        --masked_lm_prob=0.15 \
        --random_seed=12345 \
        --dupe_factor=5

2. Pretrain
First, prepare the distributed training environment, then adjust the configurations in config.py, and finally run train.py.
"""
import os
import numpy as np
from config import bert_train_cfg, bert_net_cfg
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore import context
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell
from mindspore.nn.optim import Lamb

_current_dir = os.path.dirname(os.path.realpath(__file__))

def create_train_dataset(batch_size):
    """create train dataset"""
    # apply repeat operations
    repeat_count = bert_train_cfg.epoch_size
    ds = de.TFRecordDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
                            columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
                                          "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
    type_cast_op = C.TypeCast(mstype.int32)
    ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
    ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
    ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op)
    ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
    ds = ds.map(input_columns="input_mask", operations=type_cast_op)
    ds = ds.map(input_columns="input_ids", operations=type_cast_op)
    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_count)
    return ds

def weight_variable(shape):
    """weight variable"""
    np.random.seed(1)
    ones = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32)
    return Tensor(ones)
def train_bert():
    """train bert"""
    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(device_target="Ascend")
    context.set_context(enable_task_sink=True)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    ds = create_train_dataset(bert_net_cfg.batch_size)
    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
    optimizer = Lamb(netwithloss.trainable_params(), decay_steps=bert_train_cfg.decay_steps,
                     start_learning_rate=bert_train_cfg.start_learning_rate,
                     end_learning_rate=bert_train_cfg.end_learning_rate, power=bert_train_cfg.power,
                     warmup_steps=bert_train_cfg.num_warmup_steps, decay_filter=lambda x: False)
    netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    config_ck = CheckpointConfig(save_checkpoint_steps=bert_train_cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=bert_train_cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=bert_train_cfg.checkpoint_prefix, config=config_ck)
    model.train(ds.get_repeat_count(), ds, callbacks=[LossMonitor(), ckpoint_cb], dataset_sink_mode=False)

if __name__ == '__main__':
    train_bert()
@@ -4,20 +4,26 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the zhwiki dataset from <https://dumps.wikimedia.org/zhwiki> for pre-training. Extract and clean text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). Convert the dataset to TFRecord format and move the files to a specified path.
- Download the CLUE dataset from <https://www.cluebenchmarks.com> for fine-tuning and evaluation.
> Notes:
> If you are running a fine-tuning or evaluation task, prepare the corresponding checkpoint file.
## Running the Example
### Pre-Training
- Set options in `config.py`. Make sure the 'DATA_DIR' (path to the dataset) and 'SCHEMA_DIR' (path to the JSON schema file) are set to your own paths. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about the dataset and the JSON schema file.
- Set options in `config.py`, including lossscale, optimizer and network settings. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about the dataset and the JSON schema file.
- Run `run_pretrain.py` for pre-training of the BERT-base and BERT-NEZHA models.
- Run `run_standalone_pretrain.sh` for non-distributed pre-training of the BERT-base and BERT-NEZHA models.

``` bash
python run_pretrain.py --backend=ms
```

``` bash
sh run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_PATH
```
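For example, using the example values from the script's own usage message (device 0, 40 epochs; the dataset, schema and MindSpore paths are placeholders for your environment):

``` bash
sh run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json /path/mindspore
```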
- Run `run_distribute_pretrain.sh` for distributed pre-training of the BERT-base and BERT-NEZHA models.

``` bash
sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH MINDSPORE_PATH
```
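For example, an 8-device run with the example values from the script's usage message (the HCCL JSON file describes your device topology; all paths are placeholders):

``` bash
sh run_distribute_pretrain.sh 8 40 /path/zh-wiki/ /path/Schema.json /path/hccl.json /path/mindspore
```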
### Fine-Tuning
- Set options in `finetune_config.py`. Make sure 'data_file', 'schema_file' and 'ckpt_file' are set to your own paths, and set 'pre_training_ckpt' to the checkpoint file generated by pre-training.
@@ -40,30 +46,42 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](
## Usage
### Pre-Training
```
usage: run_pretrain.py [--backend BACKEND]
optional parameters:
    --backend, BACKEND         MindSpore backend: ms
usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [--device_num N] [--device_id N]
                       [--enable_task_sink ENABLE_TASK_SINK] [--enable_loop_sink ENABLE_LOOP_SINK]
                       [--enable_mem_reuse ENABLE_MEM_REUSE] [--enable_save_ckpt ENABLE_SAVE_CKPT]
                       [--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
                       [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--checkpoint_path CHECKPOINT_PATH]
                       [--save_checkpoint_steps N] [--save_checkpoint_num N]
                       [--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR]
options:
    --distribute               pre-training on several devices: "true" (training with more than 1 device) | "false", default is "false"
    --epoch_size               epoch size: N, default is 1
    --device_num               number of devices to use: N, default is 1
    --device_id                device id: N, default is 0
    --enable_task_sink         enable task sink: "true" | "false", default is "true"
    --enable_loop_sink         enable loop sink: "true" | "false", default is "true"
    --enable_mem_reuse         enable memory reuse: "true" | "false", default is "true"
    --enable_save_ckpt         enable checkpoint saving: "true" | "false", default is "true"
    --enable_lossscale         enable lossscale: "true" | "false", default is "true"
    --do_shuffle               enable shuffle: "true" | "false", default is "true"
    --enable_data_sink         enable data sink: "true" | "false", default is "true"
    --data_sink_steps          data sink steps per epoch: N, default is 1
    --checkpoint_path          path to a checkpoint file to load: PATH, default is ""
    --save_checkpoint_steps    steps between saving checkpoint files: N, default is 1000
    --save_checkpoint_num      maximum number of checkpoint files to keep: N, default is 1
    --data_dir                 path to the dataset directory: PATH, default is ""
    --schema_dir               path to the schema.json file: PATH, default is ""
```
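As a reference, the standalone launch script added in this PR effectively runs `run_pretrain.py` as below (device id, epoch size and the data/schema paths are the example values from `run_standalone_pretrain.sh`; adjust them to your environment):

``` bash
python run_pretrain.py \
    --distribute="false" \
    --epoch_size=40 \
    --device_id=0 \
    --enable_task_sink="true" \
    --enable_loop_sink="true" \
    --enable_mem_reuse="true" \
    --enable_save_ckpt="true" \
    --enable_lossscale="true" \
    --do_shuffle="true" \
    --enable_data_sink="true" \
    --data_sink_steps=1 \
    --checkpoint_path="" \
    --save_checkpoint_steps=1000 \
    --save_checkpoint_num=1 \
    --data_dir=/path/zh-wiki/ \
    --schema_dir=/path/Schema.json > log.txt 2>&1 &
```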
## Options and Parameters
Parameters of the BERT model and options for training are set in the files `config.py`, `finetune_config.py` and `evaluation_config.py`, respectively.
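For example, the pre-training optimizer is chosen through the `cfg` dictionary in the new `config.py`; below is a minimal, self-contained sketch of the relevant fields (names and values copied from the file added in this PR):

```python
from easydict import EasyDict as edict

# Sketch of the optimizer selection in config.py: run_pretrain.py reads cfg.optimizer
# and then the matching sub-dictionary of hyper-parameters (Lamb shown here).
cfg = edict({
    'optimizer': 'Lamb',  # one of 'Lamb' | 'Momentum' | 'AdamWeightDecayDynamicLR'
    'Lamb': edict({
        'start_learning_rate': 3e-5,
        'end_learning_rate': 0.0,
        'power': 10.0,
        'warmup_steps': 10000,
        'weight_decay': 0.01,
        'eps': 1e-6,
        'decay_filter': lambda x: False,  # as in config.py
    }),
})
print(cfg.optimizer, cfg.Lamb.start_learning_rate)
```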
### Options:
```
Pre-Training:
    bert_network                version of BERT model: base | large, default is base
    epoch_size                  number of training epochs: N, default is 40
    dataset_sink_mode           use dataset sink mode or not: True | False, default is True
    do_shuffle                  shuffle the dataset or not: True | False, default is True
    do_train_with_lossscale     use lossscale or not: True | False, default is True
    loss_scale_value            initial value of loss scale: N, default is 2^32
    scale_factor                factor used to update loss scale: N, default is 2
    scale_window                steps between loss scale updates: N, default is 1000
    save_checkpoint_steps       steps between saving checkpoints: N, default is 2000
    keep_checkpoint_max         maximum number of checkpoints to keep: N, default is 1
    init_ckpt                   checkpoint file to load: PATH, default is ""
    data_dir                    dataset file to load: PATH, default is "/your/path/cn-wiki-128"
    schema_dir                  dataset schema file to load: PATH, default is "your/path/datasetSchema.json"
    optimizer                   optimizer used in the network: AdamWeightDecayDynamicLR | Lamb | Momentum, default is "Lamb"

Fine-Tuning:
@@ -0,0 +1,89 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in dataset.py, run_pretrain.py
"""
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import BertConfig
cfg = edict({
    'bert_network': 'base',
    'loss_scale_value': 2**32,
    'scale_factor': 2,
    'scale_window': 1000,
    'optimizer': 'Lamb',
    'AdamWeightDecayDynamicLR': edict({
        'learning_rate': 3e-5,
        'end_learning_rate': 0.0,
        'power': 5.0,
        'weight_decay': 1e-5,
        'eps': 1e-6,
    }),
    'Lamb': edict({
        'start_learning_rate': 3e-5,
        'end_learning_rate': 0.0,
        'power': 10.0,
        'warmup_steps': 10000,
        'weight_decay': 0.01,
        'eps': 1e-6,
        'decay_filter': lambda x: False,
    }),
    'Momentum': edict({
        'learning_rate': 2e-5,
        'momentum': 0.9,
    }),
})
if cfg.bert_network == 'base':
    bert_net_cfg = BertConfig(
        batch_size=16,
        seq_length=128,
        vocab_size=21136,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        use_relative_positions=False,
        input_mask_from_dataset=True,
        token_type_ids_from_dataset=True,
        dtype=mstype.float32,
        compute_type=mstype.float16,
    )
else:
    bert_net_cfg = BertConfig(
        batch_size=16,
        seq_length=128,
        vocab_size=21136,
        hidden_size=1024,
        num_hidden_layers=12,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        use_relative_positions=True,
        input_mask_from_dataset=True,
        token_type_ids_from_dataset=True,
        dtype=mstype.float32,
        compute_type=mstype.float16,
    )
@@ -0,0 +1,58 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in run_pretrain.py
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore import log as logger
from config import bert_net_cfg
def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", enable_data_sink="true",
                        data_sink_steps=1, data_dir=None, schema_dir=None):
    """create train dataset"""
    # apply repeat operations
    repeat_count = epoch_size
    files = os.listdir(data_dir)
    data_files = []
    for file_name in files:
        data_files.append(data_dir + file_name)
    ds = de.TFRecordDataset(data_files, schema_dir,
                            columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
                                          "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
                            shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
                            shard_equal_rows=True)
    ori_dataset_size = ds.get_dataset_size()
    new_size = ori_dataset_size
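    # When data sinking is enabled, limit each epoch to data_sink_steps batches
    # (data_sink_steps * batch_size rows) and scale the repeat count so the total
    # amount of data processed stays roughly the same.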
    if enable_data_sink == "true":
        new_size = data_sink_steps * bert_net_cfg.batch_size
    ds.set_dataset_size(new_size)
    repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
    type_cast_op = C.TypeCast(mstype.int32)
    ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
    ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
    ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op)
    ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
    ds = ds.map(input_columns="input_mask", operations=type_cast_op)
    ds = ds.map(input_columns="input_ids", operations=type_cast_op)
    # apply batch operations
    ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_count)
    logger.info("data size: {}".format(ds.get_dataset_size()))
    logger.info("repeatcount: {}".format(ds.get_repeat_count()))
    return ds
@@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH MINDSPORE_PATH"
echo "for example: sh run_distribute_pretrain.sh 8 40 /path/zh-wiki/ /path/Schema.json /path/hccl.json /path/mindspore"
echo "It is better to use absolute path."
echo "=============================================================================================================="

EPOCH_SIZE=$2
DATA_DIR=$3
SCHEMA_DIR=$4
MINDSPORE_PATH=$6

export PYTHONPATH=$MINDSPORE_PATH/build/package:$PYTHONPATH
export MINDSPORE_HCCL_CONFIG_PATH=$5
export RANK_SIZE=$1
for((i=0;i<RANK_SIZE;i++))
do
    export DEVICE_ID=$i
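    # bind this rank's training process to its own block of 12 CPU cores (i*12 .. i*12+11) via taskset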
    start=`expr $i \* 12`
    end=`expr $start \+ 11`
    cmdopt=$start"-"$end

    rm -rf LOG$i
    mkdir ./LOG$i
    cp *.py ./LOG$i
    cd ./LOG$i || exit
    export RANK_ID=$i
    echo "start training for rank $i, device $DEVICE_ID"
    env > env.log
    taskset -c $cmdopt python ../run_pretrain.py \
        --distribute="true" \
        --epoch_size=$EPOCH_SIZE \
        --device_id=$DEVICE_ID \
        --device_num=$RANK_SIZE \
        --enable_task_sink="true" \
        --enable_loop_sink="true" \
        --enable_mem_reuse="true" \
        --enable_save_ckpt="true" \
        --enable_lossscale="true" \
        --do_shuffle="true" \
        --enable_data_sink="true" \
        --data_sink_steps=1 \
        --checkpoint_path="" \
        --save_checkpoint_steps=1000 \
        --save_checkpoint_num=1 \
        --data_dir=$DATA_DIR \
        --schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
    cd ../
done
@@ -0,0 +1,144 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################pre_train bert example on zh-wiki########################
python run_pretrain.py
"""
import os
import argparse
import mindspore.communication.management as D
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.parallel_utils import ParallelMode
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR
from dataset import create_bert_dataset
from config import cfg, bert_net_cfg

_current_dir = os.path.dirname(os.path.realpath(__file__))
class LossCallBack(Callback):
    """
    Monitor the loss in training.

    If the loss is NAN or INF, terminate training.

    Note:
        If per_print_times is 0, do not print the loss.

    Args:
        per_print_times (int): Print the loss every per_print_times steps. Default: 1.
    """
    def __init__(self, per_print_times=1):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0")
        self._per_print_times = per_print_times

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        with open("./loss.log", "a+") as f:
            f.write("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
                                                                 str(cb_params.net_outputs)))
            f.write('\n')
def run_pretrain():
    """pre-train bert_clue"""
    parser = argparse.ArgumentParser(description='bert pre_training')
    parser.add_argument("--distribute", type=str, default="false", help="Run distributed training, default is false.")
    parser.add_argument("--epoch_size", type=int, default=1, help="Epoch size, default is 1.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Number of devices to use, default is 1.")
    parser.add_argument("--enable_task_sink", type=str, default="true", help="Enable task sink, default is true.")
    parser.add_argument("--enable_loop_sink", type=str, default="true", help="Enable loop sink, default is true.")
    parser.add_argument("--enable_mem_reuse", type=str, default="true", help="Enable mem reuse, default is true.")
    parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
    parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is true.")
    parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
    parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
    parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.")
    parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path to load.")
    parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
                        "default is 1000.")
    parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
    parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path.")
    parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path.")

    args_opt = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
    context.set_context(enable_task_sink=(args_opt.enable_task_sink == "true"),
                        enable_loop_sink=(args_opt.enable_loop_sink == "true"),
                        enable_mem_reuse=(args_opt.enable_mem_reuse == "true"))
    context.set_context(reserve_class_name_in_scope=False)

    if args_opt.distribute == "true":
        device_num = args_opt.device_num
        context.reset_auto_parallel_context()
        context.set_context(enable_hccl=True)
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                          device_num=device_num)
        D.init()
        rank = args_opt.device_id % device_num
    else:
        context.set_context(enable_hccl=False)
        rank = 0
        device_num = 1

    ds = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle, args_opt.enable_data_sink,
                             args_opt.data_sink_steps, args_opt.data_dir, args_opt.schema_dir)
    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)

    if cfg.optimizer == 'Lamb':
        optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
                         start_learning_rate=cfg.Lamb.start_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate,
                         power=cfg.Lamb.power, warmup_steps=cfg.Lamb.warmup_steps, weight_decay=cfg.Lamb.weight_decay,
                         eps=cfg.Lamb.eps, decay_filter=cfg.Lamb.decay_filter)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(),
                                             decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
                                             learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
                                             end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
                                             power=cfg.AdamWeightDecayDynamicLR.power,
                                             weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
                                             eps=cfg.AdamWeightDecayDynamicLR.eps)
    else:
        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecayDynamicLR]".
                         format(cfg.optimizer))
    callback = [LossCallBack()]
    if args_opt.enable_save_ckpt == "true":
        config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                     keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.checkpoint_path:
        param_dict = load_checkpoint(args_opt.checkpoint_path)
        load_param_into_net(netwithloss, param_dict)

    if args_opt.enable_lossscale == "true":
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
                                                 scale_factor=cfg.scale_factor,
                                                 scale_window=cfg.scale_window)
        netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                         scale_update_cell=update_cell)
    else:
        netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)

    model = Model(netwithgrads)
    model.train(ds.get_repeat_count(), ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))

if __name__ == '__main__':
    run_pretrain()
@@ -0,0 +1,46 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_PATH"
echo "for example: sh run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json /path/mindspore"
echo "=============================================================================================================="

DEVICE_ID=$1
EPOCH_SIZE=$2
DATA_DIR=$3
SCHEMA_DIR=$4
MINDSPORE_PATH=$5

export PYTHONPATH=$MINDSPORE_PATH/build/package:$PYTHONPATH
python run_pretrain.py \
    --distribute="false" \
    --epoch_size=$EPOCH_SIZE \
    --device_id=$DEVICE_ID \
    --enable_task_sink="true" \
    --enable_loop_sink="true" \
    --enable_mem_reuse="true" \
    --enable_save_ckpt="true" \
    --enable_lossscale="true" \
    --do_shuffle="true" \
    --enable_data_sink="true" \
    --data_sink_steps=1 \
    --checkpoint_path="" \
    --save_checkpoint_steps=1000 \
    --save_checkpoint_num=1 \
    --data_dir=$DATA_DIR \
    --schema_dir=$SCHEMA_DIR > log.txt 2>&1 &