Merge pull request !4460 from zongha/master (tags/v0.7.0-beta)
@@ -11,14 +11,14 @@ This is an example of training bert by second-order optimizer THOR. THOR is a no
 ## Running the Example
 ### Pre-Training
-- Set options in `config.py`, including lossscale, optimizer and network. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about dataset and the json schema file.
+- Set options in `config.py`, including optimizer and network. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about dataset and the json schema file.
-- Run `run_standalone_pretrain.sh` for non-distributed pre-training of BERT-base and BERT-NEZHA model.
+- Run `run_standalone_pretrain.sh` for non-distributed pre-training of BERT-base, BERT-NEZHA and BERT-large model.
 ``` bash
 sh scripts/run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR
 ```
-- Run `run_distribute_pretrain.sh` for distributed pre-training of BERT-base and BERT-NEZHA model.
+- Run `run_distribute_pretrain.sh` for distributed pre-training of BERT-base, BERT-NEZHA and BERT-large model.
 ``` bash
 sh scripts/run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR RANK_TABLE_FILE
@@ -30,7 +30,7 @@ This is an example of training bert by second-order optimizer THOR. THOR is a no
 usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_num N] [--device_id N]
 [--enable_save_ckpt ENABLE_SAVE_CKPT]
 [--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
-[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--checkpoint_path CHECKPOINT_PATH]
+[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--save_checkpoint_path CHECKPOINT_PATH]
 [--save_checkpoint_steps N] [--save_checkpoint_num N]
 [--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR]
@@ -44,7 +44,7 @@ options:
 --do_shuffle enable shuffle: "true" | "false", default is "true"
 --enable_data_sink enable data sink: "true" | "false", default is "true"
 --data_sink_steps set data sink steps: N, default is 1
---checkpoint_path path to save checkpoint files: PATH, default is ""
+--save_checkpoint_path path to save checkpoint files: PATH, default is ""
 --save_checkpoint_steps steps for saving checkpoint files: N, default is 1000
 --save_checkpoint_num number for saving checkpoint files: N, default is 1
 --data_dir path to dataset directory: PATH, default is ""
@@ -55,7 +55,7 @@ It contains of parameters of BERT model and options for training, which is set i
 ### Options:
 ```
 config.py:
-bert_network version of BERT model: base | nezha, default is base
+bert_network version of BERT model: base | nezha | large, default is large
 optimizer optimizer used in the network: AdamWerigtDecayDynamicLR | Lamb | Momentum | Thor, default is "Thor"
 ```
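For reference, a minimal sketch of the Thor-related part of `config.py` after this change, assembled from the config hunk further down in this diff; other keys (the AdamWeightDecay/Lamb/Momentum sections and the Thor learning-rate and damping settings) are omitted here.

```python
# Abridged sketch only; the values shown are taken from the config.py hunk later in this diff.
from easydict import EasyDict as edict

cfg = edict({
    'bert_network': 'large',      # base | nezha | large
    'optimizer': 'Thor',          # AdamWeightDecay | Lamb | Momentum | Thor
    'Thor': edict({
        'momentum': 0.9,
        'weight_decay': 5e-4,
        'loss_scale': 1,
        'frequency': 100,         # step interval for refreshing the second-order matrices
        'batch_size': 12,
    }),
})
```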
@@ -63,7 +63,7 @@ config.py:
 ### Parameters:
 ```
 Parameters for dataset and network (Pre-Training/Evaluation):
-batch_size batch size of input dataset: N, default is 8
+batch_size batch size of input dataset: N, default is 12
 seq_length length of input sequence: N, default is 128
 vocab_size size of each embedding vector: N, must be consistant with the dataset you use. Default is 21136
 hidden_size size of bert encoder layers: N, default is 768
@@ -87,7 +87,7 @@ Parameters for optimizer:
 momentum momentum for the moving average: Q
 weight_decay weight decay: Q
 loss_scale loss scale: N
-frequency the step interval to update second-order information matrix: N, default is 10
-batch_size batch size of input dataset: N, default is 8
+frequency the step interval to update second-order information matrix: N, default is 100
+batch_size batch size of input dataset: N, default is 12
 ```
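These Thor options are consumed directly when `run_pretrain.py` builds the optimizer. A sketch mirroring the updated call (the same call appears in the run_pretrain.py hunk below; the `A_inv_max`/`G_inv_max` parameter tuples are gone from the new signature), assuming `net_with_loss`, `lr` and `damping` have already been built earlier in the script:

```python
# Sketch of the updated THOR construction; mirrors the run_pretrain.py hunk below.
# `net_with_loss`, `lr` and `damping` are assumed to exist at this point in the script.
optimizer = THOR(filter(lambda x: x.requires_grad, net_with_loss.get_parameters()),
                 lr, cfg.Thor.momentum,
                 filter(lambda x: 'matrix_A' in x.name, net_with_loss.get_parameters()),
                 filter(lambda x: 'matrix_G' in x.name, net_with_loss.get_parameters()),
                 cfg.Thor.weight_decay, cfg.Thor.loss_scale,
                 bert_net_cfg.num_hidden_layers, bert_net_cfg.batch_size, damping)
```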
@@ -19,7 +19,6 @@ python run_pretrain.py
 import argparse
 import os
 import numpy
 from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
 from src.bert_net_config import bert_net_cfg
@@ -27,10 +26,8 @@ from src.config import cfg
 from src.dataset import create_bert_dataset
 from src.lr_generator import get_bert_lr, get_bert_damping
 from src.model_thor import Model
-# from src.thor_for_bert import THOR
 from src.thor_for_bert_arg import THOR
 from src.utils import LossCallBack, BertLearningRate
 import mindspore.common.dtype as mstype
 import mindspore.communication.management as D
 from mindspore import context
@@ -69,8 +66,8 @@ def run_pretrain():
 parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
 args_opt = parser.parse_args()
-context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id,
-save_graphs=True)
+context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
+device_id=args_opt.device_id, save_graphs=False)
 context.set_context(reserve_class_name_in_scope=False)
 context.set_context(variable_memory_max_size="30GB")
 ckpt_save_dir = args_opt.save_checkpoint_path
@@ -165,15 +162,13 @@ def run_pretrain():
 optimizer = THOR(filter(lambda x: x.requires_grad, net_with_loss.get_parameters()), lr, cfg.Thor.momentum,
 filter(lambda x: 'matrix_A' in x.name, net_with_loss.get_parameters()),
 filter(lambda x: 'matrix_G' in x.name, net_with_loss.get_parameters()),
-filter(lambda x: 'A_inv_max' in x.name, net_with_loss.get_parameters()),
-filter(lambda x: 'G_inv_max' in x.name, net_with_loss.get_parameters()),
 cfg.Thor.weight_decay, cfg.Thor.loss_scale, bert_net_cfg.num_hidden_layers,
 bert_net_cfg.batch_size, damping)
 else:
-raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
+raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
 format(cfg.optimizer))
 callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
-if args_opt.enable_save_ckpt == "true":
+if args_opt.enable_save_ckpt == "true" and rank == 0:
 config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
 keep_checkpoint_max=args_opt.save_checkpoint_num)
 ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', directory=ckpt_save_dir, config=config_ck)
@@ -37,25 +37,26 @@ do
 rm -rf LOG$i
 mkdir ./LOG$i
-cp *.py ./LOG$i
-cp -r src ./LOG$i
+cp ../*.py ./LOG$i
+cp -r ../src ./LOG$i
 cd ./LOG$i || exit
-echo "start training for rank $i, device $DEVICE_ID"
+echo "start training for rank $RANK_ID, device $DEVICE_ID"
 env > env.log
-python ../run_pretrain.py \
+python run_pretrain.py \
 --distribute="true" \
 --epoch_size=$EPOCH_SIZE \
 --device_id=$DEVICE_ID \
 --device_num=$RANK_SIZE \
 --enable_save_ckpt="true" \
 --enable_lossscale="false" \
---do_shuffle="true" \
+--do_shuffle="false" \
 --enable_data_sink="true" \
 --data_sink_steps=1000 \
 --load_checkpoint_path="" \
---save_checkpoint_steps=5000 \
+--save_checkpoint_path='./' \
+--save_checkpoint_steps=1000 \
 --save_checkpoint_num=30 \
 --data_dir=$DATA_DIR \
 --schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
 cd ../
-done
+done
@@ -20,27 +20,39 @@ echo "bash run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR"
 echo "for example: bash run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json"
 echo "=============================================================================================================="
-DEVICE_ID=$1
 EPOCH_SIZE=$2
 DATA_DIR=$3
 SCHEMA_DIR=$4
-mkdir -p ms_log
-PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
-CUR_DIR=`pwd`
-export GLOG_log_dir=${CUR_DIR}/ms_log
-export GLOG_logtostderr=0
-python ${PROJECT_DIR}/../run_pretrain.py \
---distribute="false" \
---epoch_size=$EPOCH_SIZE \
---device_id=$DEVICE_ID \
---enable_save_ckpt="true" \
---enable_lossscale="true" \
---do_shuffle="true" \
---enable_data_sink="true" \
---data_sink_steps=1 \
---load_checkpoint_path="" \
---save_checkpoint_steps=10000 \
---save_checkpoint_num=1 \
---data_dir=$DATA_DIR \
---schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
+ulimit -u unlimited
+export DEVICE_ID=$1
+export RANK_SIZE=1
+if [ -d "LOG" ];
+then
+rm -rf ./LOG
+fi
+mkdir ./LOG
+cp ../*.py ./LOG
+cp -r ../src ./LOG
+cd ./LOG || exit
+echo "start training for device $DEVICE_ID"
+env > env.log
+python run_pretrain.py \
+--distribute="false" \
+--epoch_size=$EPOCH_SIZE \
+--device_id=$DEVICE_ID \
+--device_num=$RANK_SIZE \
+--enable_save_ckpt="true" \
+--enable_lossscale="false" \
+--do_shuffle="false" \
+--enable_data_sink="true" \
+--data_sink_steps=1000 \
+--load_checkpoint_path="" \
+--save_checkpoint_path='./' \
+--save_checkpoint_steps=5000 \
+--save_checkpoint_num=20 \
+--data_dir=$DATA_DIR \
+--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
+cd ../
@@ -35,6 +35,8 @@ from .thor_layer import Dense_Thor
 damping = get_bert_damping()
 loss_scale = cfg.Thor.loss_scale
+frequency = cfg.Thor.frequency
+batch_size = cfg.Thor.batch_size
 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0
@@ -91,9 +93,9 @@ class GetMaskedLMOutput(nn.Cell):
 bias_init='zeros',
 damping=damping,
 loss_scale=loss_scale,
-frequency=1,
+frequency=frequency,
 activation=config.hidden_act,
-batch_size=config.batch_size).to_float(config.compute_type)
+batch_size=batch_size).to_float(config.compute_type)
 self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type)
 self.output_bias = Parameter(
 initializer(
@@ -34,6 +34,7 @@ from .thor_layer import Dense_Thor, Embedding_Thor
 damping = get_bert_damping()
 loss_scale = cfg.Thor.loss_scale
+frequency = cfg.Thor.frequency
 batch_size = cfg.Thor.batch_size
@@ -200,11 +201,10 @@ class EmbeddingPostprocessor(nn.Cell):
 use_one_hot_embeddings=use_one_hot_embeddings,
 initializer_range=initializer_range,
 name='embedding_table',
-is_expand=False,
 batch_size=batch_size,
 damping=damping,
 loss_scale=loss_scale,
-frequency=1)
+frequency=frequency)
 self.shape_flat = (-1,)
 self.one_hot = P.OneHot()
 self.on_value = Tensor(1.0, mstype.float32)
@@ -225,11 +225,10 @@ class EmbeddingPostprocessor(nn.Cell):
 use_one_hot_embeddings=use_one_hot_embeddings,
 initializer_range=initializer_range,
 name='full_position_embeddings',
-is_expand=False,
 batch_size=batch_size,
 damping=damping,
 loss_scale=loss_scale,
-frequency=1)
+frequency=frequency)
 self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
 self.layernorm = nn.LayerNorm((embedding_size,))
@@ -274,7 +273,7 @@ class BertOutput(nn.Cell):
 bias_init='zeros',
 damping=damping,
 loss_scale=loss_scale,
-frequency=1,
+frequency=frequency,
 activation=None,
 batch_size=batch_size).to_float(compute_type)
 self.dropout = nn.Dropout(1 - dropout_prob)
@@ -488,7 +487,7 @@ class BertAttention(nn.Cell):
 bias_init='zeros',
 damping=damping,
 loss_scale=loss_scale,
-frequency=1,
+frequency=frequency,
 activation=query_act,
 batch_size=batch_size).to_float(compute_type)
 self.key_layer = Dense_Thor(in_channels=to_tensor_width,
@@ -498,7 +497,7 @@ class BertAttention(nn.Cell):
 bias_init='zeros',
 damping=damping,
 loss_scale=loss_scale,
-frequency=1,
+frequency=frequency,
 activation=key_act,
 batch_size=batch_size).to_float(compute_type)
 self.value_layer = Dense_Thor(in_channels=to_tensor_width,
@@ -508,7 +507,7 @@ class BertAttention(nn.Cell):
 bias_init='zeros',
 damping=damping,
 loss_scale=loss_scale,
-frequency=1,
+frequency=frequency,
 activation=value_act,
 batch_size=batch_size).to_float(compute_type)
 self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
@@ -764,7 +763,7 @@ class BertEncoderCell(nn.Cell):
 bias_init='zeros',
 damping=damping,
 loss_scale=loss_scale,
-frequency=1,
+frequency=frequency,
 activation=hidden_act,
 batch_size=batch_size).to_float(compute_type)
 self.output = BertOutput(in_channels=intermediate_size,
@@ -945,11 +944,10 @@ class BertModel(nn.Cell):
 use_one_hot_embeddings=use_one_hot_embeddings,
 initializer_range=config.initializer_range,
 name='embedding_table',
-is_expand=True,
 batch_size=batch_size,
 damping=damping,
 loss_scale=loss_scale,
-frequency=1)
+frequency=frequency)
 self.bert_embedding_postprocessor = EmbeddingPostprocessor(
 embedding_size=self.embedding_size,
 embedding_shape=output_embedding_shape,
@@ -991,7 +989,7 @@ class BertModel(nn.Cell):
 bias_init='zeros',
 damping=damping,
 loss_scale=loss_scale,
-frequency=1,
+frequency=frequency,
 activation="tanh",
 batch_size=batch_size).to_float(config.compute_type)
 self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)
@@ -19,9 +19,6 @@ from easydict import EasyDict as edict
 cfg = edict({
 'bert_network': 'large',
-'loss_scale_value': 65536,
-'scale_factor': 2,
-'scale_window': 1000,
 'optimizer': 'Thor',
 'AdamWeightDecay': edict({
 'learning_rate': 3e-5,
@@ -49,7 +46,7 @@ cfg = edict({
 'momentum': 0.9,
 'weight_decay': 5e-4,
 'loss_scale': 1,
-'frequency': 10,
-'batch_size': 8,
+'frequency': 100,
+'batch_size': 12,
 }),
 })
@@ -16,7 +16,6 @@
 Data operations, will be used in run_pretrain.py
 """
 import os
 import mindspore.common.dtype as mstype
 import mindspore.dataset.engine.datasets as de
 import mindspore.dataset.transforms.c_transforms as C
@@ -37,7 +36,7 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None,
 columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
 "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
 shuffle=de.Shuffle.FILES if do_shuffle == "true" else False,
-num_shards=device_num, shard_id=rank, shard_equal_rows=True)
+num_shards=device_num, shard_id=rank, shard_equal_rows=False)
 ori_dataset_size = ds.get_dataset_size()
 print('origin dataset size: ', ori_dataset_size)
 type_cast_op = C.TypeCast(mstype.int32)
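The hunk above turns off `shard_equal_rows` when the TFRecord files are sharded across devices. A hypothetical usage sketch of the helper follows; the hunk header truncates the full keyword list, so the `schema_dir` argument is an assumption carried over from the launch scripts rather than something this diff shows.

```python
# Hypothetical call sketch; only device_num, rank, do_shuffle and data_dir are visible in
# the hunk header above, schema_dir is assumed from the surrounding scripts.
device_num, rank = 8, 0
ds = create_bert_dataset(device_num=device_num, rank=rank,
                         do_shuffle="false",              # matches the distributed launch script
                         data_dir="/path/zh-wiki/",
                         schema_dir="/path/Schema.json")
print("dataset size per shard:", ds.get_dataset_size())
```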
@@ -80,7 +80,7 @@ def _tensors_cast_datatype(datatype, grad):
 return F.cast(grad, datatype)
-class DistributedGradReducerThor1(Cell):
+class DistributedGradReducerThor(Cell):
 """
 A distributed optimizer.
@@ -154,7 +154,7 @@ class DistributedGradReducerThor1(Cell):
 """
 def __init__(self, parameters, group, mean=True, degree=None):
-super(DistributedGradReducerThor1, self).__init__(auto_prefix=False)
+super(DistributedGradReducerThor, self).__init__(auto_prefix=False)
 self.hyper_map = C.HyperMap()
 self.mul = P.Mul()
 if degree is None:
@@ -168,7 +168,7 @@ class DistributedGradReducerThor1(Cell):
 _init_optimizer_allreduce(group)
 def construct(self, grads):
-"""construct of DistributedGradReducerThor1"""
+"""construct of DistributedGradReducerThor"""
 # In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
 # result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
 # and cast back after the operation.
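The comment above is the rationale for the reducer's precision handling. A tiny illustrative sketch of that cast-AllReduce-cast pattern on a single gradient tensor is shown below; it is not the `DistributedGradReducerThor` implementation itself.

```python
# Illustrative only: reduce in float32, then restore the incoming precision.
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F, operations as P

def allreduce_fp32(grad, allreduce=P.AllReduce()):
    original_dtype = F.dtype(grad)        # incoming gradients may be float16 or float32
    grad = F.cast(grad, mstype.float32)   # AllReduce in float32 so the sum is reliable
    grad = allreduce(grad)
    return F.cast(grad, original_dtype)   # cast back after the operation
```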
@@ -58,7 +58,7 @@ def get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps,
 # bert kfac hyperparam setting
 def get_bert_lr():
 learning_rate = Tensor(
-get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=4e-4, warmup_steps=0, total_steps=30000,
+get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=3.1e-3, warmup_steps=0, total_steps=30000,
 poly_power=1))
 return learning_rate
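The only change here is `lr_max`, raised from 4e-4 to 3.1e-3. As a rough reference, a common polynomial-decay schedule with these arguments looks like the sketch below; the real `get_poly_lr` is not shown in this diff and may differ in details such as warmup handling.

```python
import numpy as np

def poly_decay_lr(lr_init, lr_end, lr_max, warmup_steps, total_steps, poly_power):
    """Generic polynomial-decay schedule, an approximation of get_poly_lr rather than its exact code."""
    lrs = []
    for step in range(total_steps):
        if warmup_steps and step < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * step / warmup_steps      # linear warmup
        else:
            frac = 1.0 - (step - warmup_steps) / (total_steps - warmup_steps)
            lr = (lr_max - lr_end) * frac ** poly_power + lr_end         # polynomial decay
        lrs.append(lr)
    return np.array(lrs, dtype=np.float32)

# With the new values: starts near 3.1e-3 and decays linearly to about 1e-6 over 30000 steps.
lr_schedule = poly_decay_lr(0.0, 1e-6, 3.1e-3, 0, 30000, 1)
```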
@@ -46,9 +46,8 @@ def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
 class THOR(Optimizer):
 """THOR"""
-def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0,
-loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03, frequency=10,
+def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, weight_decay=0.0,
+loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03,
 decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()):
 super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale)
 if isinstance(momentum, float) and momentum < 0.0:
@@ -60,8 +59,6 @@ class THOR(Optimizer):
 self.opt = P.ApplyMomentum()
 self.matrix_A = ParameterTuple(matrix_A)
 self.matrix_G = ParameterTuple(matrix_G)
-self.A_inv_max = ParameterTuple(A_inv_max)
-self.G_inv_max = ParameterTuple(G_inv_max)
 self.matmul = P.MatMul()
 self.transpose = P.Transpose()
 self.shape = P.Shape()
@@ -70,16 +67,8 @@ class THOR(Optimizer):
 self.gather = P.GatherV2()
 self.matrix_A_inv = ()
 self.matrix_G_inv = ()
-self.matrix_max_inv = ()
 self.num_hidden_layers = num_hidden_layers
-fc_layer_num = num_hidden_layers * 6 + 5
-for i in range(fc_layer_num):
-self.matrix_max_inv = self.matrix_max_inv + (
-Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),)
-self.log = P.Log()
-self.exp = P.Exp()
 self.sqrt = P.Sqrt()
-self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
 self.assign = P.Assign()
 self.cast = P.Cast()
 self.thor = True
@@ -90,7 +79,6 @@ class THOR(Optimizer):
 self.inv = P.Inv()
 self.batch_size = batch_size
 self.damping = damping
-self.freq = Tensor(frequency, mstype.int32)
 self.one = Tensor(1, mstype.int32)
 self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
@@ -106,26 +94,20 @@ class THOR(Optimizer):
 g = gradients[em_idx]
 matrix_idx = em_idx
 temp_a_ori = self.matrix_A[matrix_idx]
-temp_a = self.expand(temp_a_ori, 1)
 temp_g = self.matrix_G[matrix_idx]
-G_max = self.G_inv_max[matrix_idx]
-temp_g = self.cast(temp_g, mstype.float32)
-matrix_G_inv_max = self.log(G_max)
-matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
-matrix_G_inv_max = self.exp(matrix_G_inv_max)
-temp_g = self.mul(temp_g, matrix_G_inv_max)
-g = self.mul(temp_a, g)
-g = self.cast(g, mstype.float16)
+temp_a_ori = F.depend(temp_a_ori, g)
+temp_g = F.depend(temp_g, g)
+temp_a = self.expand(temp_a_ori, 1)
+temp_a = self.cast(temp_a, mstype.float16)
 temp_g = self.cast(temp_g, mstype.float16)
+g = self.cast(g, mstype.float16)
+g = self.mul(temp_a, g)
 g = self.matmul(g, temp_g)
 g = self.cast(g, mstype.float32)
-g = self.mul(g, G_max)
 fake_A = self.assign(self.matrix_A[matrix_idx], temp_a_ori)
 fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
-fake_max = self.assign(self.matrix_max_inv[matrix_idx], G_max)
 g = F.depend(g, fake_A)
 g = F.depend(g, fake_G)
-g = F.depend(g, fake_max)
 new_grads = new_grads + (g,)
 # process bert_embedding_postprocessor.layernorm
 grad_idx = 3
@@ -180,32 +162,18 @@ class THOR(Optimizer):
 matrix_idx = 6 * i + offset_idx + 3
 temp_a = self.matrix_A[matrix_idx]
 temp_g = self.matrix_G[matrix_idx]
-temp_a = self.cast(temp_a, mstype.float32)
-temp_g = self.cast(temp_g, mstype.float32)
-matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx])
-matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
-matrix_A_inv_max = self.exp(matrix_A_inv_max)
-temp_a = self.mul(temp_a, matrix_A_inv_max)
-matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx])
-matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
-matrix_G_inv_max = self.exp(matrix_G_inv_max)
-temp_g = self.mul(temp_g, matrix_G_inv_max)
-temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx])
+temp_a = F.depend(temp_a, g)
+temp_g = F.depend(temp_g, g)
 temp_a = self.cast(temp_a, mstype.float16)
 temp_g = self.cast(temp_g, mstype.float16)
 g = self.cast(g, mstype.float16)
 g = self.matmul(temp_g, g)
 g = self.matmul(g, temp_a)
 g = self.cast(g, mstype.float32)
-g = self.mul(g, temp_max)
 fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
 fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
-fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max)
 g = F.depend(g, fake_A)
 g = F.depend(g, fake_G)
-g = F.depend(g, fake_max)
 new_grads = new_grads + (g,)
 new_grads = new_grads + (gradients[grad_idx + 1],)
@@ -216,32 +184,18 @@ class THOR(Optimizer):
 pooler_bias = gradients[pooler_layer_idx + 1]
 temp_a = self.matrix_A[matrix_idx]
 temp_g = self.matrix_G[matrix_idx]
-temp_a = self.cast(temp_a, mstype.float32)
-temp_g = self.cast(temp_g, mstype.float32)
-matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx])
-matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
-matrix_A_inv_max = self.exp(matrix_A_inv_max)
-temp_a = self.mul(temp_a, matrix_A_inv_max)
-matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx])
-matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
-matrix_G_inv_max = self.exp(matrix_G_inv_max)
-temp_g = self.mul(temp_g, matrix_G_inv_max)
-temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx])
+temp_a = F.depend(temp_a, g)
+temp_g = F.depend(temp_g, g)
 temp_a = self.cast(temp_a, mstype.float16)
 temp_g = self.cast(temp_g, mstype.float16)
 g = self.cast(g, mstype.float16)
 g = self.matmul(temp_g, g)
 g = self.matmul(g, temp_a)
 g = self.cast(g, mstype.float32)
-g = self.mul(g, temp_max)
 fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
 fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
-fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max)
 g = F.depend(g, fake_A)
 g = F.depend(g, fake_G)
-g = F.depend(g, fake_max)
 new_grads = new_grads + (g, pooler_bias)
 # for cls1 fc layer: mlm
@@ -251,38 +205,26 @@ class THOR(Optimizer):
 mlm_bias = gradients[mlm_fc_idx + 1]
 temp_a = self.matrix_A[matrix_idx]
 temp_g = self.matrix_G[matrix_idx]
-temp_a = self.cast(temp_a, mstype.float32)
-temp_g = self.cast(temp_g, mstype.float32)
-matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx])
-matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
-matrix_A_inv_max = self.exp(matrix_A_inv_max)
-temp_a = self.mul(temp_a, matrix_A_inv_max)
-matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx])
-matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
-matrix_G_inv_max = self.exp(matrix_G_inv_max)
-temp_g = self.mul(temp_g, matrix_G_inv_max)
-temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx])
+temp_a = F.depend(temp_a, g)
+temp_g = F.depend(temp_g, g)
 temp_a = self.cast(temp_a, mstype.float16)
 temp_g = self.cast(temp_g, mstype.float16)
 g = self.cast(g, mstype.float16)
 g = self.matmul(temp_g, g)
 g = self.matmul(g, temp_a)
 g = self.cast(g, mstype.float32)
-g = self.mul(g, temp_max)
-# add bert.cls1.output_bias grad
 fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
 fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
-fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max)
 g = F.depend(g, fake_A)
 g = F.depend(g, fake_G)
-g = F.depend(g, fake_max)
 new_grads = new_grads + (gradients[mlm_fc_idx - 1],)
 new_grads = new_grads + (g, mlm_bias)
 # add bert.cls1.layernorm grad
 begin_idx = mlm_fc_idx + 2
 end_idx = mlm_fc_idx + 4
 new_grads = new_grads + gradients[begin_idx: end_idx]
 lenth = len(gradients)
 new_grads = new_grads + gradients[lenth - 2: lenth]
 gradients = new_grads
@@ -293,15 +235,16 @@ class THOR(Optimizer):
 g = gradients[em_idx]
 matrix_idx = em_idx
 temp_a = self.matrix_A[matrix_idx]
-temp_a = self.expand(temp_a, 1)
 temp_g = self.matrix_G[matrix_idx]
-matrix_max = self.matrix_max_inv[matrix_idx]
-g = self.mul(temp_a, g)
+temp_a = F.depend(temp_a, g)
+temp_g = F.depend(temp_g, g)
+temp_a = self.expand(temp_a, 1)
+temp_a = self.cast(temp_a, mstype.float16)
 temp_g = self.cast(temp_g, mstype.float16)
 g = self.cast(g, mstype.float16)
+g = self.mul(temp_a, g)
 g = self.matmul(g, temp_g)
 g = self.cast(g, mstype.float32)
-g = self.mul(g, matrix_max)
 new_grads = new_grads + (g,)
 # process bert_embedding_postprocessor.layernorm
 grad_idx = 3
@@ -356,15 +299,14 @@ class THOR(Optimizer):
 matrix_idx = 6 * i + offset_idx + 3
 temp_a = self.matrix_A[matrix_idx]
 temp_g = self.matrix_G[matrix_idx]
-matrix_max = self.matrix_max_inv[matrix_idx]
+temp_a = F.depend(temp_a, g)
+temp_g = F.depend(temp_g, g)
 temp_a = self.cast(temp_a, mstype.float16)
 temp_g = self.cast(temp_g, mstype.float16)
 g = self.cast(g, mstype.float16)
 g = self.matmul(temp_g, g)
 g = self.matmul(g, temp_a)
 g = self.cast(g, mstype.float32)
-g = self.mul(g, matrix_max)
 new_grads = new_grads + (g,)
 new_grads = new_grads + (gradients[grad_idx + 1],)
@@ -375,15 +317,14 @@ class THOR(Optimizer):
 pooler_bias = gradients[pooler_layer_idx + 1]
 temp_a = self.matrix_A[matrix_idx]
 temp_g = self.matrix_G[matrix_idx]
-matrix_max = self.matrix_max_inv[matrix_idx]
+temp_a = F.depend(temp_a, g)
+temp_g = F.depend(temp_g, g)
 temp_a = self.cast(temp_a, mstype.float16)
 temp_g = self.cast(temp_g, mstype.float16)
 g = self.cast(g, mstype.float16)
 g = self.matmul(temp_g, g)
 g = self.matmul(g, temp_a)
 g = self.cast(g, mstype.float32)
-g = self.mul(g, matrix_max)
 new_grads = new_grads + (g, pooler_bias)
 # for cls1 fc layer: mlm
@@ -393,15 +334,14 @@ class THOR(Optimizer):
 mlm_bias = gradients[mlm_fc_idx + 1]
 temp_a = self.matrix_A[matrix_idx]
 temp_g = self.matrix_G[matrix_idx]
-matrix_max = self.matrix_max_inv[matrix_idx]
+temp_a = F.depend(temp_a, g)
+temp_g = F.depend(temp_g, g)
 temp_a = self.cast(temp_a, mstype.float16)
 temp_g = self.cast(temp_g, mstype.float16)
 g = self.cast(g, mstype.float16)
 g = self.matmul(temp_g, g)
 g = self.matmul(g, temp_a)
 g = self.cast(g, mstype.float32)
-g = self.mul(g, matrix_max)
 # add bert.cls1.output_bias grad
 new_grads = new_grads + (gradients[mlm_fc_idx - 1],)
 new_grads = new_grads + (g, mlm_bias)
@@ -409,6 +349,7 @@ class THOR(Optimizer):
 begin_idx = mlm_fc_idx + 2
 end_idx = mlm_fc_idx + 4
 new_grads = new_grads + gradients[begin_idx: end_idx]
 lenth = len(gradients)
 new_grads = new_grads + gradients[lenth - 2: lenth]
 gradients = new_grads
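Net effect of the THOR hunks above: the `A_inv_max`/`G_inv_max` parameter tuples, the log/exp rescaling and the trailing `mul(g, matrix_max)` are gone, so a fully-connected layer's gradient is now preconditioned directly by the stored inverse factors. A rough NumPy sketch of that per-layer update (illustrative only, not the MindSpore graph code):

```python
# Illustrative NumPy sketch of the simplified per-FC-layer THOR preconditioning:
# g <- G_inv @ g @ A_inv, computed in float16 and cast back to float32,
# with no separate *_inv_max rescaling factors.
import numpy as np

def precondition_fc_grad(grad, matrix_a_inv, matrix_g_inv):
    g16 = grad.astype(np.float16)
    a16 = matrix_a_inv.astype(np.float16)
    ginv16 = matrix_g_inv.astype(np.float16)
    out = ginv16 @ g16          # matmul(temp_g, g)
    out = out @ a16             # matmul(g, temp_a)
    return out.astype(np.float32)
```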
| @@ -21,7 +21,7 @@ from mindspore.common.tensor import Tensor | |||||
| from mindspore.nn.optim.optimizer import Optimizer | from mindspore.nn.optim.optimizer import Optimizer | ||||
| from mindspore.ops import functional as F, composite as C, operations as P | from mindspore.ops import functional as F, composite as C, operations as P | ||||
| from mindspore.parallel._utils import _get_device_num, _get_mirror_mean | from mindspore.parallel._utils import _get_device_num, _get_mirror_mean | ||||
| from .grad_reducer_thor1 import DistributedGradReducerThor1 | |||||
| from .grad_reducer_thor import DistributedGradReducerThor | |||||
| momentum_opt = C.MultitypeFuncGraph("momentum_opt") | momentum_opt = C.MultitypeFuncGraph("momentum_opt") | ||||
| @@ -48,9 +48,8 @@ def _tensor_apply_decay(weight_decay, if_apply, weight, gradient): | |||||
| class THOR(Optimizer): | class THOR(Optimizer): | ||||
| """THOR""" | """THOR""" | ||||
| def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0, | |||||
| loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03, frequency=10, | |||||
| def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, weight_decay=0.0, | |||||
| loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03, | |||||
| decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()): | decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()): | ||||
| super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale) | super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale) | ||||
| if isinstance(momentum, float) and momentum < 0.0: | if isinstance(momentum, float) and momentum < 0.0: | ||||
| @@ -62,8 +61,6 @@ class THOR(Optimizer): | |||||
| self.opt = P.ApplyMomentum() | self.opt = P.ApplyMomentum() | ||||
| self.matrix_A = ParameterTuple(matrix_A) | self.matrix_A = ParameterTuple(matrix_A) | ||||
| self.matrix_G = ParameterTuple(matrix_G) | self.matrix_G = ParameterTuple(matrix_G) | ||||
| self.A_inv_max = ParameterTuple(A_inv_max) | |||||
| self.G_inv_max = ParameterTuple(G_inv_max) | |||||
| self.matmul = P.MatMul() | self.matmul = P.MatMul() | ||||
| self.transpose = P.Transpose() | self.transpose = P.Transpose() | ||||
| self.shape = P.Shape() | self.shape = P.Shape() | ||||
| @@ -72,16 +69,8 @@ class THOR(Optimizer): | |||||
| self.gather = P.GatherV2() | self.gather = P.GatherV2() | ||||
| self.matrix_A_inv = () | self.matrix_A_inv = () | ||||
| self.matrix_G_inv = () | self.matrix_G_inv = () | ||||
| self.matrix_max_inv = () | |||||
| self.num_hidden_layers = num_hidden_layers | self.num_hidden_layers = num_hidden_layers | ||||
| fc_layer_num = num_hidden_layers * 6 + 5 | |||||
| for i in range(fc_layer_num): | |||||
| self.matrix_max_inv = self.matrix_max_inv + ( | |||||
| Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),) | |||||
| self.log = P.Log() | |||||
| self.exp = P.Exp() | |||||
| self.sqrt = P.Sqrt() | self.sqrt = P.Sqrt() | ||||
| self.matrix_max_inv = ParameterTuple(self.matrix_max_inv) | |||||
| self.assign = P.Assign() | self.assign = P.Assign() | ||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| self.thor = True | self.thor = True | ||||
| @@ -92,12 +81,11 @@ class THOR(Optimizer): | |||||
| self.inv = P.Inv() | self.inv = P.Inv() | ||||
| self.batch_size = batch_size | self.batch_size = batch_size | ||||
| self.damping = damping | self.damping = damping | ||||
| self.freq = Tensor(frequency, mstype.int32) | |||||
| self.one = Tensor(1, mstype.int32) | self.one = Tensor(1, mstype.int32) | ||||
| self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False) | self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False) | ||||
| mean = _get_mirror_mean() | mean = _get_mirror_mean() | ||||
| degree = _get_device_num() | degree = _get_device_num() | ||||
| self.grad_reducer_g = DistributedGradReducerThor1(self.parameters, 3, mean, degree) | |||||
| self.grad_reducer_g = DistributedGradReducerThor(self.parameters, 3, mean, degree) | |||||
| def construct(self, gradients): | def construct(self, gradients): | ||||
| """construct of THOR""" | """construct of THOR""" | ||||
| @@ -111,26 +99,20 @@ class THOR(Optimizer): | |||||
| g = gradients[em_idx] | g = gradients[em_idx] | ||||
| matrix_idx = em_idx | matrix_idx = em_idx | ||||
| temp_a_ori = self.matrix_A[matrix_idx] | temp_a_ori = self.matrix_A[matrix_idx] | ||||
| temp_a = self.expand(temp_a_ori, 1) | |||||
| temp_g = self.matrix_G[matrix_idx] | temp_g = self.matrix_G[matrix_idx] | ||||
| G_max = self.G_inv_max[matrix_idx] | |||||
| temp_g = self.cast(temp_g, mstype.float32) | |||||
| matrix_G_inv_max = self.log(G_max) | |||||
| matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) | |||||
| matrix_G_inv_max = self.exp(matrix_G_inv_max) | |||||
| temp_g = self.mul(temp_g, matrix_G_inv_max) | |||||
| g = self.mul(temp_a, g) | |||||
| g = self.cast(g, mstype.float16) | |||||
| temp_a_ori = F.depend(temp_a_ori, g) | |||||
| temp_g = F.depend(temp_g, g) | |||||
| temp_a = self.expand(temp_a_ori, 1) | |||||
| temp_a = self.cast(temp_a, mstype.float16) | |||||
| temp_g = self.cast(temp_g, mstype.float16) | temp_g = self.cast(temp_g, mstype.float16) | ||||
| g = self.cast(g, mstype.float16) | |||||
| g = self.mul(temp_a, g) | |||||
| g = self.matmul(g, temp_g) | g = self.matmul(g, temp_g) | ||||
| g = self.cast(g, mstype.float32) | g = self.cast(g, mstype.float32) | ||||
| g = self.mul(g, G_max) | |||||
| fake_A = self.assign(self.matrix_A[matrix_idx], temp_a_ori) | fake_A = self.assign(self.matrix_A[matrix_idx], temp_a_ori) | ||||
| fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) | fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) | ||||
| fake_max = self.assign(self.matrix_max_inv[matrix_idx], G_max) | |||||
| g = F.depend(g, fake_A) | g = F.depend(g, fake_A) | ||||
| g = F.depend(g, fake_G) | g = F.depend(g, fake_G) | ||||
| g = F.depend(g, fake_max) | |||||
| new_grads = new_grads + (g,) | new_grads = new_grads + (g,) | ||||
| # process bert_embedding_postprocessor.layernorm | # process bert_embedding_postprocessor.layernorm | ||||
| grad_idx = 3 | grad_idx = 3 | ||||
| @@ -185,32 +167,18 @@ class THOR(Optimizer): | |||||
| matrix_idx = 6 * i + offset_idx + 3 | matrix_idx = 6 * i + offset_idx + 3 | ||||
| temp_a = self.matrix_A[matrix_idx] | temp_a = self.matrix_A[matrix_idx] | ||||
| temp_g = self.matrix_G[matrix_idx] | temp_g = self.matrix_G[matrix_idx] | ||||
| temp_a = self.cast(temp_a, mstype.float32) | |||||
| temp_g = self.cast(temp_g, mstype.float32) | |||||
| matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx]) | |||||
| matrix_A_inv_max = self.mul(matrix_A_inv_max, -1) | |||||
| matrix_A_inv_max = self.exp(matrix_A_inv_max) | |||||
| temp_a = self.mul(temp_a, matrix_A_inv_max) | |||||
| matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx]) | |||||
| matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) | |||||
| matrix_G_inv_max = self.exp(matrix_G_inv_max) | |||||
| temp_g = self.mul(temp_g, matrix_G_inv_max) | |||||
| temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx]) | |||||
| temp_a = F.depend(temp_a, g) | |||||
| temp_g = F.depend(temp_g, g) | |||||
| temp_a = self.cast(temp_a, mstype.float16) | temp_a = self.cast(temp_a, mstype.float16) | ||||
| temp_g = self.cast(temp_g, mstype.float16) | temp_g = self.cast(temp_g, mstype.float16) | ||||
| g = self.cast(g, mstype.float16) | g = self.cast(g, mstype.float16) | ||||
| g = self.matmul(temp_g, g) | g = self.matmul(temp_g, g) | ||||
| g = self.matmul(g, temp_a) | g = self.matmul(g, temp_a) | ||||
| g = self.cast(g, mstype.float32) | g = self.cast(g, mstype.float32) | ||||
| g = self.mul(g, temp_max) | |||||
| fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) | fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) | ||||
| fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) | fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) | ||||
| fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max) | |||||
| g = F.depend(g, fake_A) | g = F.depend(g, fake_A) | ||||
| g = F.depend(g, fake_G) | g = F.depend(g, fake_G) | ||||
| g = F.depend(g, fake_max) | |||||
| new_grads = new_grads + (g,) | new_grads = new_grads + (g,) | ||||
| new_grads = new_grads + (gradients[grad_idx + 1],) | new_grads = new_grads + (gradients[grad_idx + 1],) | ||||
| @@ -221,32 +189,18 @@ class THOR(Optimizer): | |||||
| pooler_bias = gradients[pooler_layer_idx + 1] | pooler_bias = gradients[pooler_layer_idx + 1] | ||||
| temp_a = self.matrix_A[matrix_idx] | temp_a = self.matrix_A[matrix_idx] | ||||
| temp_g = self.matrix_G[matrix_idx] | temp_g = self.matrix_G[matrix_idx] | ||||
| temp_a = self.cast(temp_a, mstype.float32) | |||||
| temp_g = self.cast(temp_g, mstype.float32) | |||||
| matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx]) | |||||
| matrix_A_inv_max = self.mul(matrix_A_inv_max, -1) | |||||
| matrix_A_inv_max = self.exp(matrix_A_inv_max) | |||||
| temp_a = self.mul(temp_a, matrix_A_inv_max) | |||||
| matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx]) | |||||
| matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) | |||||
| matrix_G_inv_max = self.exp(matrix_G_inv_max) | |||||
| temp_g = self.mul(temp_g, matrix_G_inv_max) | |||||
| temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx]) | |||||
| temp_a = F.depend(temp_a, g) | |||||
| temp_g = F.depend(temp_g, g) | |||||
| temp_a = self.cast(temp_a, mstype.float16) | temp_a = self.cast(temp_a, mstype.float16) | ||||
| temp_g = self.cast(temp_g, mstype.float16) | temp_g = self.cast(temp_g, mstype.float16) | ||||
| g = self.cast(g, mstype.float16) | g = self.cast(g, mstype.float16) | ||||
| g = self.matmul(temp_g, g) | g = self.matmul(temp_g, g) | ||||
| g = self.matmul(g, temp_a) | g = self.matmul(g, temp_a) | ||||
| g = self.cast(g, mstype.float32) | g = self.cast(g, mstype.float32) | ||||
| g = self.mul(g, temp_max) | |||||
| fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) | fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) | ||||
| fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) | fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) | ||||
| fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max) | |||||
| g = F.depend(g, fake_A) | g = F.depend(g, fake_A) | ||||
| g = F.depend(g, fake_G) | g = F.depend(g, fake_G) | ||||
| g = F.depend(g, fake_max) | |||||
| new_grads = new_grads + (g, pooler_bias) | new_grads = new_grads + (g, pooler_bias) | ||||
| # for cls1 fc layer: mlm | # for cls1 fc layer: mlm | ||||
| @@ -256,38 +210,26 @@ class THOR(Optimizer): | |||||
| mlm_bias = gradients[mlm_fc_idx + 1] | mlm_bias = gradients[mlm_fc_idx + 1] | ||||
| temp_a = self.matrix_A[matrix_idx] | temp_a = self.matrix_A[matrix_idx] | ||||
| temp_g = self.matrix_G[matrix_idx] | temp_g = self.matrix_G[matrix_idx] | ||||
| temp_a = self.cast(temp_a, mstype.float32) | |||||
| temp_g = self.cast(temp_g, mstype.float32) | |||||
| matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx]) | |||||
| matrix_A_inv_max = self.mul(matrix_A_inv_max, -1) | |||||
| matrix_A_inv_max = self.exp(matrix_A_inv_max) | |||||
| temp_a = self.mul(temp_a, matrix_A_inv_max) | |||||
| matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx]) | |||||
| matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) | |||||
| matrix_G_inv_max = self.exp(matrix_G_inv_max) | |||||
| temp_g = self.mul(temp_g, matrix_G_inv_max) | |||||
| temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx]) | |||||
| temp_a = F.depend(temp_a, g) | |||||
| temp_g = F.depend(temp_g, g) | |||||
| temp_a = self.cast(temp_a, mstype.float16) | temp_a = self.cast(temp_a, mstype.float16) | ||||
| temp_g = self.cast(temp_g, mstype.float16) | temp_g = self.cast(temp_g, mstype.float16) | ||||
| g = self.cast(g, mstype.float16) | g = self.cast(g, mstype.float16) | ||||
| g = self.matmul(temp_g, g) | g = self.matmul(temp_g, g) | ||||
| g = self.matmul(g, temp_a) | g = self.matmul(g, temp_a) | ||||
| g = self.cast(g, mstype.float32) | g = self.cast(g, mstype.float32) | ||||
| g = self.mul(g, temp_max) | |||||
| # add bert.cls1.output_bias grad | |||||
| fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) | fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) | ||||
| fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) | fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) | ||||
| fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max) | |||||
| g = F.depend(g, fake_A) | g = F.depend(g, fake_A) | ||||
| g = F.depend(g, fake_G) | g = F.depend(g, fake_G) | ||||
| g = F.depend(g, fake_max) | |||||
| new_grads = new_grads + (gradients[mlm_fc_idx - 1],) | new_grads = new_grads + (gradients[mlm_fc_idx - 1],) | ||||
| new_grads = new_grads + (g, mlm_bias) | new_grads = new_grads + (g, mlm_bias) | ||||
| # add bert.cls1.layernorm grad | # add bert.cls1.layernorm grad | ||||
| begin_idx = mlm_fc_idx + 2 | begin_idx = mlm_fc_idx + 2 | ||||
| end_idx = mlm_fc_idx + 4 | end_idx = mlm_fc_idx + 4 | ||||
| new_grads = new_grads + gradients[begin_idx: end_idx] | new_grads = new_grads + gradients[begin_idx: end_idx] | ||||
| length = len(gradients) | length = len(gradients) | ||||
| new_grads = new_grads + gradients[length - 2: length] | new_grads = new_grads + gradients[length - 2: length] | ||||
| gradients = new_grads | gradients = new_grads | ||||
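The hunk above drops the `A_inv_max` / `G_inv_max` bookkeeping: the old path divided the cached inverses by their max absolute entries before the float16 matmuls (note that `exp(-log(x))` is simply `1/x`) and multiplied `temp_max = A_inv_max * G_inv_max` back in afterwards, presumably to keep the float16 operands in range. Both paths compute the same preconditioned gradient `g <- G^-1 * g * A^-1`. A minimal NumPy sketch of that equivalence (shapes and values invented for illustration, not MindSpore code):

``` python
import numpy as np

rng = np.random.default_rng(0)
g = rng.standard_normal((8, 4))        # weight gradient, shape [out, in]
a_inv = rng.standard_normal((4, 4))    # cached A^-1
g_inv = rng.standard_normal((8, 8))    # cached G^-1

# simplified path kept by this diff: g <- G^-1 @ g @ A^-1
direct = g_inv @ g @ a_inv

# removed path: rescale by the max entries, then scale the product back up
a_max, g_max = np.abs(a_inv).max(), np.abs(g_inv).max()
scaled = (g_inv * np.exp(-np.log(g_max))) @ g @ (a_inv * np.exp(-np.log(a_max)))
rescaled = scaled * (a_max * g_max)

assert np.allclose(direct, rescaled)   # identical up to rounding
```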
| @@ -299,15 +241,16 @@ class THOR(Optimizer): | |||||
| g = gradients[em_idx] | g = gradients[em_idx] | ||||
| matrix_idx = em_idx | matrix_idx = em_idx | ||||
| temp_a = self.matrix_A[matrix_idx] | temp_a = self.matrix_A[matrix_idx] | ||||
| temp_a = self.expand(temp_a, 1) | |||||
| temp_g = self.matrix_G[matrix_idx] | temp_g = self.matrix_G[matrix_idx] | ||||
| matrix_max = self.matrix_max_inv[matrix_idx] | |||||
| g = self.mul(temp_a, g) | |||||
| temp_a = F.depend(temp_a, g) | |||||
| temp_g = F.depend(temp_g, g) | |||||
| temp_a = self.expand(temp_a, 1) | |||||
| temp_a = self.cast(temp_a, mstype.float16) | |||||
| temp_g = self.cast(temp_g, mstype.float16) | temp_g = self.cast(temp_g, mstype.float16) | ||||
| g = self.cast(g, mstype.float16) | g = self.cast(g, mstype.float16) | ||||
| g = self.mul(temp_a, g) | |||||
| g = self.matmul(g, temp_g) | g = self.matmul(g, temp_g) | ||||
| g = self.cast(g, mstype.float32) | g = self.cast(g, mstype.float32) | ||||
| g = self.mul(g, matrix_max) | |||||
| new_grads = new_grads + (g,) | new_grads = new_grads + (g,) | ||||
| # process bert_embedding_postprocessor.layernorm | # process bert_embedding_postprocessor.layernorm | ||||
| grad_idx = 3 | grad_idx = 3 | ||||
| @@ -362,15 +305,14 @@ class THOR(Optimizer): | |||||
| matrix_idx = 6 * i + offset_idx + 3 | matrix_idx = 6 * i + offset_idx + 3 | ||||
| temp_a = self.matrix_A[matrix_idx] | temp_a = self.matrix_A[matrix_idx] | ||||
| temp_g = self.matrix_G[matrix_idx] | temp_g = self.matrix_G[matrix_idx] | ||||
| matrix_max = self.matrix_max_inv[matrix_idx] | |||||
| temp_a = F.depend(temp_a, g) | |||||
| temp_g = F.depend(temp_g, g) | |||||
| temp_a = self.cast(temp_a, mstype.float16) | temp_a = self.cast(temp_a, mstype.float16) | ||||
| temp_g = self.cast(temp_g, mstype.float16) | temp_g = self.cast(temp_g, mstype.float16) | ||||
| g = self.cast(g, mstype.float16) | g = self.cast(g, mstype.float16) | ||||
| g = self.matmul(temp_g, g) | g = self.matmul(temp_g, g) | ||||
| g = self.matmul(g, temp_a) | g = self.matmul(g, temp_a) | ||||
| g = self.cast(g, mstype.float32) | g = self.cast(g, mstype.float32) | ||||
| g = self.mul(g, matrix_max) | |||||
| new_grads = new_grads + (g,) | new_grads = new_grads + (g,) | ||||
| new_grads = new_grads + (gradients[grad_idx + 1],) | new_grads = new_grads + (gradients[grad_idx + 1],) | ||||
| @@ -381,15 +323,14 @@ class THOR(Optimizer): | |||||
| pooler_bias = gradients[pooler_layer_idx + 1] | pooler_bias = gradients[pooler_layer_idx + 1] | ||||
| temp_a = self.matrix_A[matrix_idx] | temp_a = self.matrix_A[matrix_idx] | ||||
| temp_g = self.matrix_G[matrix_idx] | temp_g = self.matrix_G[matrix_idx] | ||||
| matrix_max = self.matrix_max_inv[matrix_idx] | |||||
| temp_a = F.depend(temp_a, g) | |||||
| temp_g = F.depend(temp_g, g) | |||||
| temp_a = self.cast(temp_a, mstype.float16) | temp_a = self.cast(temp_a, mstype.float16) | ||||
| temp_g = self.cast(temp_g, mstype.float16) | temp_g = self.cast(temp_g, mstype.float16) | ||||
| g = self.cast(g, mstype.float16) | g = self.cast(g, mstype.float16) | ||||
| g = self.matmul(temp_g, g) | g = self.matmul(temp_g, g) | ||||
| g = self.matmul(g, temp_a) | g = self.matmul(g, temp_a) | ||||
| g = self.cast(g, mstype.float32) | g = self.cast(g, mstype.float32) | ||||
| g = self.mul(g, matrix_max) | |||||
| new_grads = new_grads + (g, pooler_bias) | new_grads = new_grads + (g, pooler_bias) | ||||
| # for cls1 fc layer: mlm | # for cls1 fc layer: mlm | ||||
| @@ -399,15 +340,14 @@ class THOR(Optimizer): | |||||
| mlm_bias = gradients[mlm_fc_idx + 1] | mlm_bias = gradients[mlm_fc_idx + 1] | ||||
| temp_a = self.matrix_A[matrix_idx] | temp_a = self.matrix_A[matrix_idx] | ||||
| temp_g = self.matrix_G[matrix_idx] | temp_g = self.matrix_G[matrix_idx] | ||||
| matrix_max = self.matrix_max_inv[matrix_idx] | |||||
| temp_a = F.depend(temp_a, g) | |||||
| temp_g = F.depend(temp_g, g) | |||||
| temp_a = self.cast(temp_a, mstype.float16) | temp_a = self.cast(temp_a, mstype.float16) | ||||
| temp_g = self.cast(temp_g, mstype.float16) | temp_g = self.cast(temp_g, mstype.float16) | ||||
| g = self.cast(g, mstype.float16) | g = self.cast(g, mstype.float16) | ||||
| g = self.matmul(temp_g, g) | g = self.matmul(temp_g, g) | ||||
| g = self.matmul(g, temp_a) | g = self.matmul(g, temp_a) | ||||
| g = self.cast(g, mstype.float32) | g = self.cast(g, mstype.float32) | ||||
| g = self.mul(g, matrix_max) | |||||
| # add bert.cls1.output_bias grad | # add bert.cls1.output_bias grad | ||||
| new_grads = new_grads + (gradients[mlm_fc_idx - 1],) | new_grads = new_grads + (gradients[mlm_fc_idx - 1],) | ||||
| new_grads = new_grads + (g, mlm_bias) | new_grads = new_grads + (g, mlm_bias) | ||||
| @@ -415,6 +355,7 @@ class THOR(Optimizer): | |||||
| begin_idx = mlm_fc_idx + 2 | begin_idx = mlm_fc_idx + 2 | ||||
| end_idx = mlm_fc_idx + 4 | end_idx = mlm_fc_idx + 4 | ||||
| new_grads = new_grads + gradients[begin_idx: end_idx] | new_grads = new_grads + gradients[begin_idx: end_idx] | ||||
| length = len(gradients) | length = len(gradients) | ||||
| new_grads = new_grads + gradients[length - 2: length] | new_grads = new_grads + gradients[length - 2: length] | ||||
| gradients = new_grads | gradients = new_grads | ||||
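Each fully connected layer in the hunks above follows the same pattern: read the cached `matrix_A` / `matrix_G` inverses, pin that read before any overwrite with `F.depend`, run the two float16 matmuls, and, in the update branch only, persist the freshly computed inverses with `Assign`, again tied into the graph through `F.depend` so the side effect is not pruned. A condensed sketch of that pattern as a standalone Cell (names, shapes and the plain `MatMul` are illustrative, not the optimizer's actual members):

``` python
import mindspore.common.dtype as mstype
from mindspore import Parameter, Tensor
from mindspore.nn.cell import Cell
from mindspore.ops import functional as F
from mindspore.ops import operations as P

class PreconditionSketch(Cell):
    """Hedged sketch of one per-layer THOR step: precondition g in float16
    with the cached inverses, then cache the new inverses via Assign."""
    def __init__(self, a_inv, g_inv):
        super(PreconditionSketch, self).__init__()
        self.matrix_A = Parameter(Tensor(a_inv), name="matrix_A", requires_grad=False)
        self.matrix_G = Parameter(Tensor(g_inv), name="matrix_G", requires_grad=False)
        self.cast = P.Cast()
        self.matmul = P.MatMul()
        self.assign = P.Assign()

    def construct(self, g, new_a_inv, new_g_inv):
        temp_a = self.matrix_A
        temp_g = self.matrix_G
        # make sure the cached inverses are read before they are overwritten
        temp_a = F.depend(temp_a, g)
        temp_g = F.depend(temp_g, g)
        g = self.cast(g, mstype.float16)
        g = self.matmul(self.cast(temp_g, mstype.float16), g)
        g = self.matmul(g, self.cast(temp_a, mstype.float16))
        g = self.cast(g, mstype.float32)
        # persist the freshly computed inverses; depend keeps the Assign alive
        fake_a = self.assign(self.matrix_A, new_a_inv)
        fake_g = self.assign(self.matrix_G, new_g_inv)
        g = F.depend(g, fake_a)
        g = F.depend(g, fake_g)
        return g
```

The later hunks (from `@@ -362` on) appear to be the reuse branch: they run the same float16 preconditioning but leave the cached inverses untouched, so the `Assign` calls are absent there.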
| @@ -14,7 +14,6 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """thor_layer""" | """thor_layer""" | ||||
| import numpy as np | import numpy as np | ||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from mindspore._checkparam import check_bool, check_int_positive | from mindspore._checkparam import check_bool, check_int_positive | ||||
| from mindspore.common.initializer import TruncatedNormal, initializer | from mindspore.common.initializer import TruncatedNormal, initializer | ||||
| @@ -24,7 +23,6 @@ from mindspore.nn.cell import Cell | |||||
| from mindspore.nn.layer.activation import get_activation | from mindspore.nn.layer.activation import get_activation | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| class Embedding_Thor(Cell): | class Embedding_Thor(Cell): | ||||
| """ | """ | ||||
| An embedding lookup table with a fixed dictionary and size. | An embedding lookup table with a fixed dictionary and size. | ||||
| @@ -37,7 +35,6 @@ class Embedding_Thor(Cell): | |||||
| use_one_hot_embeddings (bool): Specifies whether to use one-hot encoding. Default: False. | use_one_hot_embeddings (bool): Specifies whether to use one-hot encoding. Default: False. | ||||
| initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. | initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. | ||||
| """ | """ | ||||
| def __init__(self, | def __init__(self, | ||||
| vocab_size, | vocab_size, | ||||
| embedding_size, | embedding_size, | ||||
| @@ -45,11 +42,10 @@ class Embedding_Thor(Cell): | |||||
| use_one_hot_embeddings=False, | use_one_hot_embeddings=False, | ||||
| initializer_range=0.02, | initializer_range=0.02, | ||||
| name='embedding_table', | name='embedding_table', | ||||
| is_expand=False, | |||||
| batch_size=12, | batch_size=12, | ||||
| damping=0.03, | damping=0.03, | ||||
| loss_scale=1, | loss_scale=1, | ||||
| frequency=10, | |||||
| frequency=100, | |||||
| ): | ): | ||||
| super(Embedding_Thor, self).__init__() | super(Embedding_Thor, self).__init__() | ||||
| self.vocab_size = vocab_size | self.vocab_size = vocab_size | ||||
| @@ -59,7 +55,6 @@ class Embedding_Thor(Cell): | |||||
| [vocab_size, embedding_size]), | [vocab_size, embedding_size]), | ||||
| name=name) | name=name) | ||||
| self.thor = True | self.thor = True | ||||
| self.is_expand = is_expand | |||||
| self.expand = P.ExpandDims() | self.expand = P.ExpandDims() | ||||
| self.shape_flat = (-1,) | self.shape_flat = (-1,) | ||||
| self.gather = P.GatherV2() | self.gather = P.GatherV2() | ||||
| @@ -71,13 +66,11 @@ class Embedding_Thor(Cell): | |||||
| self.em_shape = tuple(embedding_shape) | self.em_shape = tuple(embedding_shape) | ||||
| self.shape = P.Shape() | self.shape = P.Shape() | ||||
| self.loss_scale = Tensor(1 / loss_scale, mstype.float16) | self.loss_scale = Tensor(1 / loss_scale, mstype.float16) | ||||
| self.matrix_A_inv = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)), name='matrix_A_inv', | |||||
| requires_grad=False) | |||||
| self.matrix_A_inv = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float16)), | |||||
| name='matrix_A_inv', requires_grad=False) | |||||
| self.matrix_G_inv = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float16)), | self.matrix_G_inv = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float16)), | ||||
| name="matrix_G_inv", requires_grad=False) | name="matrix_G_inv", requires_grad=False) | ||||
| self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False) | |||||
| self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False) | |||||
| self.fused_abs_max = P.CusFusedAbsMax1() | |||||
| self.fake_G = Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float16)) | self.fake_G = Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float16)) | ||||
| self.dampingA = Tensor(np.ones([vocab_size]).astype(np.float32)) | self.dampingA = Tensor(np.ones([vocab_size]).astype(np.float32)) | ||||
| self.dampingG = Tensor(np.identity(embedding_size), mstype.float32) | self.dampingG = Tensor(np.identity(embedding_size), mstype.float32) | ||||
| @@ -117,9 +110,6 @@ class Embedding_Thor(Cell): | |||||
| matrix_G = matrix_G + damping * dampingG | matrix_G = matrix_G + damping * dampingG | ||||
| matrix_G_inv = self.cholesky(matrix_G) | matrix_G_inv = self.cholesky(matrix_G) | ||||
| matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv) | matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv) | ||||
| matrix_G_inv_max = self.fused_abs_max(matrix_G_inv) | |||||
| matrix_G_inv_max = self.fused_abs_max(matrix_G_inv_max) | |||||
| self.G_inv_max = matrix_G_inv_max | |||||
| matrix_G_inv = self.matrix_combine(matrix_G_inv) | matrix_G_inv = self.matrix_combine(matrix_G_inv) | ||||
| matrix_G_inv = self.cast(matrix_G_inv, mstype.float16) | matrix_G_inv = self.cast(matrix_G_inv, mstype.float16) | ||||
| self.matrix_G_inv = matrix_G_inv | self.matrix_G_inv = matrix_G_inv | ||||
| @@ -127,8 +117,6 @@ class Embedding_Thor(Cell): | |||||
| def construct(self, input_ids): | def construct(self, input_ids): | ||||
| """construct of Embedding_Thor""" | """construct of Embedding_Thor""" | ||||
| if self.is_expand: | |||||
| input_ids = self.expand(input_ids, -1) | |||||
| flat_ids = self.reshape(input_ids, self.shape_flat) | flat_ids = self.reshape(input_ids, self.shape_flat) | ||||
| if self.use_one_hot_embeddings: | if self.use_one_hot_embeddings: | ||||
| one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) | one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) | ||||
| @@ -146,6 +134,7 @@ class Embedding_Thor(Cell): | |||||
| dampingA = self.cast(self.dampingA, mstype.float32) | dampingA = self.cast(self.dampingA, mstype.float32) | ||||
| matrix_A = matrix_A + damping * dampingA | matrix_A = matrix_A + damping * dampingA | ||||
| matrix_A_inv = self.inv(matrix_A) | matrix_A_inv = self.inv(matrix_A) | ||||
| matrix_A_inv = self.cast(matrix_A_inv, mstype.float16) | |||||
| self.matrix_A_inv = matrix_A_inv | self.matrix_A_inv = matrix_A_inv | ||||
| self.matrix_G_inv = self.fake_G | self.matrix_G_inv = self.fake_G | ||||
| output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) | output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) | ||||
| @@ -156,11 +145,9 @@ class Embedding_Thor(Cell): | |||||
| output = self.reshape(output_for_reshape, self.em_shape) | output = self.reshape(output_for_reshape, self.em_shape) | ||||
| return output, self.embedding_table | return output, self.embedding_table | ||||
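In `Embedding_Thor`, the input factor A reduces to a per-token statistic over the batch, so `matrix_A_inv` stays a plain vector: the damped elementwise inverse, now cast to float16 before being cached (the added `self.cast(matrix_A_inv, mstype.float16)` line). On the optimizer side (the `em_idx` block earlier), the embedding-table gradient is preconditioned by row-scaling with that vector and right-multiplying by `G^-1`. A rough NumPy sketch, where raw token counts stand in for the real normalized statistic and all shapes are invented:

``` python
import numpy as np

vocab_size, embedding_size, damping = 16, 8, 0.03
rng = np.random.default_rng(1)

flat_ids = rng.integers(0, vocab_size, size=32)           # token ids in the batch
grad = rng.standard_normal((vocab_size, embedding_size))  # embedding-table gradient
g_inv = np.eye(embedding_size)                            # stands in for matrix_G_inv

# layer side: A is per-token, so its damped inverse is elementwise
counts = np.bincount(flat_ids, minlength=vocab_size).astype(np.float64)
a_inv = 1.0 / (counts + damping)                          # cached as float16 in the diff

# optimizer side: row-scale by A^-1, then right-multiply by G^-1
precond = (a_inv[:, None] * grad) @ g_inv
```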
| class Dense_Thor(Cell): | class Dense_Thor(Cell): | ||||
| """Dense_Thor""" | """Dense_Thor""" | ||||
| # @cell_attr_register(attrs=['has_bias', 'activation', 'in_channels', 'out_channels']) | |||||
| def __init__(self, | def __init__(self, | ||||
| in_channels, | in_channels, | ||||
| out_channels, | out_channels, | ||||
| @@ -168,7 +155,7 @@ class Dense_Thor(Cell): | |||||
| bias_init='zeros', | bias_init='zeros', | ||||
| damping=0.03, | damping=0.03, | ||||
| loss_scale=1, | loss_scale=1, | ||||
| frequency=10, | |||||
| frequency=100, | |||||
| has_bias=False, | has_bias=False, | ||||
| activation=None, | activation=None, | ||||
| batch_size=12): | batch_size=12): | ||||
| @@ -200,9 +187,6 @@ class Dense_Thor(Cell): | |||||
| name='matrix_A_inv', requires_grad=False) | name='matrix_A_inv', requires_grad=False) | ||||
| self.matrix_G_inv = Parameter(Tensor(np.zeros([out_channels, out_channels]).astype(np.float16)), | self.matrix_G_inv = Parameter(Tensor(np.zeros([out_channels, out_channels]).astype(np.float16)), | ||||
| name="matrix_G_inv", requires_grad=False) | name="matrix_G_inv", requires_grad=False) | ||||
| self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False) | |||||
| self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False) | |||||
| self.fused_abs_max = P.CusFusedAbsMax1() | |||||
| self.fake_G = Tensor(np.zeros([out_channels, out_channels]).astype(np.float16)) | self.fake_G = Tensor(np.zeros([out_channels, out_channels]).astype(np.float16)) | ||||
| self.matmul = P.MatMul(transpose_b=True) | self.matmul = P.MatMul(transpose_b=True) | ||||
| @@ -250,9 +234,6 @@ class Dense_Thor(Cell): | |||||
| matrix_G = matrix_G + damping * dampingG | matrix_G = matrix_G + damping * dampingG | ||||
| matrix_G_inv = self.cholesky(matrix_G) | matrix_G_inv = self.cholesky(matrix_G) | ||||
| matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv) | matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv) | ||||
| matrix_G_inv_max = self.fused_abs_max(matrix_G_inv) | |||||
| matrix_G_inv_max = self.fused_abs_max(matrix_G_inv_max) | |||||
| self.G_inv_max = matrix_G_inv_max | |||||
| matrix_G_inv = self.matrix_combine(matrix_G_inv) | matrix_G_inv = self.matrix_combine(matrix_G_inv) | ||||
| matrix_G_inv = self.cast(matrix_G_inv, mstype.float16) | matrix_G_inv = self.cast(matrix_G_inv, mstype.float16) | ||||
| self.matrix_G_inv = matrix_G_inv | self.matrix_G_inv = matrix_G_inv | ||||
| @@ -265,7 +246,6 @@ class Dense_Thor(Cell): | |||||
| shape = self.shape(x) | shape = self.shape(x) | ||||
| normalizer = self.cast(shape[0], mstype.float32) | normalizer = self.cast(shape[0], mstype.float32) | ||||
| matrix_A = self.mul(inputs, 1.0 / normalizer) | matrix_A = self.mul(inputs, 1.0 / normalizer) | ||||
| damping_step = self.gather(self.damping, self.cov_step, self.axis) | damping_step = self.gather(self.damping, self.cov_step, self.axis) | ||||
| damping_step = self.cast(damping_step, mstype.float32) | damping_step = self.cast(damping_step, mstype.float32) | ||||
| damping = self.sqrt(damping_step) | damping = self.sqrt(damping_step) | ||||
| @@ -273,9 +253,6 @@ class Dense_Thor(Cell): | |||||
| matrix_A = matrix_A + damping * dampingA | matrix_A = matrix_A + damping * dampingA | ||||
| matrix_A_inv = self.cholesky(matrix_A) | matrix_A_inv = self.cholesky(matrix_A) | ||||
| matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv) | matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv) | ||||
| matrix_A_inv_max = self.fused_abs_max(matrix_A_inv) | |||||
| matrix_A_inv_max = self.fused_abs_max(matrix_A_inv_max) | |||||
| self.A_inv_max = matrix_A_inv_max | |||||
| matrix_A_inv = self.matrix_combine(matrix_A_inv) | matrix_A_inv = self.matrix_combine(matrix_A_inv) | ||||
| matrix_A_inv = self.cast(matrix_A_inv, mstype.float16) | matrix_A_inv = self.cast(matrix_A_inv, mstype.float16) | ||||
| self.matrix_A_inv = matrix_A_inv | self.matrix_A_inv = matrix_A_inv | ||||
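`Dense_Thor` builds `matrix_A` from the layer inputs (normalized by the batch dimension, as in the `1.0 / normalizer` line above), adds the damping term, and sends it through the `cholesky` / `vector_matmul` / `matrix_combine` chain before caching the float16 result. Assuming that chain amounts to inverting the damped covariance, which is a reading of the code rather than a claim about the custom kernels, a rough NumPy analogue is:

``` python
import numpy as np

def damped_inverse_fp16(cov, damping):
    """Invert the damped covariance via its Cholesky factor, cache as float16."""
    n = cov.shape[0]
    a = cov + damping * np.eye(n)
    l = np.linalg.cholesky(a)                    # a = l @ l.T
    l_inv = np.linalg.inv(l)
    return (l_inv.T @ l_inv).astype(np.float16)  # a^-1 = l^-T @ l^-1

# made-up activations x of shape [batch, in_channels]
x = np.random.default_rng(2).standard_normal((12, 64))
cov_a = x.T @ x / x.shape[0]
a_inv = damped_inverse_fp16(cov_a, damping=np.sqrt(0.03))  # damping = sqrt(damping_step) above
```

With `frequency` raised from 10 to 100 in this change, these inverses would be refreshed only every 100 steps and read from the cached Parameters in between.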