Merge pull request !4460 from zongha/master (tag v0.7.0-beta)
@@ -11,14 +11,14 @@ This is an example of training bert by second-order optimizer THOR. THOR is a no
## Running the Example
### Pre-Training
- Set options in `config.py`, including lossscale, optimizer and network. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about dataset and the json schema file.
- Set options in `config.py`, including optimizer and network. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about dataset and the json schema file.
- Run `run_standalone_pretrain.sh` for non-distributed pre-training of BERT-base and BERT-NEZHA model.
- Run `run_standalone_pretrain.sh` for non-distributed pre-training of BERT-base, BERT-NEZHA and BERT-large model.
``` bash
sh scripts/run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR
```
- Run `run_distribute_pretrain.sh` for distributed pre-training of BERT-base and BERT-NEZHA model.
- Run `run_distribute_pretrain.sh` for distributed pre-training of BERT-base, BERT-NEZHA and BERT-large model.
``` bash
sh scripts/run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR RANK_TABLE_FILE
@@ -30,7 +30,7 @@ This is an example of training bert by second-order optimizer THOR. THOR is a no
usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [--device_num N] [--device_id N]
[--enable_save_ckpt ENABLE_SAVE_CKPT]
[--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--checkpoint_path CHECKPOINT_PATH]
[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--save_checkpoint_path CHECKPOINT_PATH]
[--save_checkpoint_steps N] [--save_checkpoint_num N]
[--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR]
@@ -44,7 +44,7 @@ options:
--do_shuffle enable shuffle: "true" | "false", default is "true"
--enable_data_sink enable data sink: "true" | "false", default is "true"
--data_sink_steps set data sink steps: N, default is 1
--checkpoint_path path to save checkpoint files: PATH, default is ""
--save_checkpoint_path path to save checkpoint files: PATH, default is ""
--save_checkpoint_steps steps for saving checkpoint files: N, default is 1000
--save_checkpoint_num number for saving checkpoint files: N, default is 1
--data_dir path to dataset directory: PATH, default is ""
@@ -55,7 +55,7 @@ It contains of parameters of BERT model and options for training, which is set i
### Options:
```
config.py:
bert_network version of BERT model: base | nezha, default is base
bert_network version of BERT model: base | nezha | large, default is large
optimizer optimizer used in the network: AdamWeightDecayDynamicLR | Lamb | Momentum | Thor, default is "Thor"
```
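For orientation, the two defaults above correspond to the top of `src/config.py`. A minimal sketch, using only the keys that appear in the config.py hunk later in this diff (all other entries omitted):

```python
from easydict import EasyDict as edict

# Sketch of the updated top-level defaults in src/config.py; values are taken
# from the config.py hunk in this PR, other entries are left out for brevity.
cfg = edict({
    'bert_network': 'large',   # was 'base'
    'loss_scale_value': 65536,
    'optimizer': 'Thor',
})
```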
@@ -63,7 +63,7 @@ config.py:
### Parameters:
```
Parameters for dataset and network (Pre-Training/Evaluation):
batch_size batch size of input dataset: N, default is 8
batch_size batch size of input dataset: N, default is 12
seq_length length of input sequence: N, default is 128
vocab_size size of each embedding vector: N, must be consistent with the dataset you use. Default is 21136
hidden_size size of bert encoder layers: N, default is 768
@@ -87,7 +87,7 @@ Parameters for optimizer:
momentum momentum for the moving average: Q
weight_decay weight decay: Q
loss_scale loss scale: N
frequency the step interval to update second-order information matrix: N, default is 10
batch_size batch size of input dataset: N, default is 8
frequency the step interval to update second-order information matrix: N, default is 100
batch_size batch size of input dataset: N, default is 12
```
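The `frequency` and `batch_size` defaults listed here live in the `Thor` block of `src/config.py`. A minimal sketch of that block, assuming the surrounding keys stay as shown in the config.py hunk further down:

```python
from easydict import EasyDict as edict

# Sketch of cfg.Thor after this PR; values follow the config.py hunk below.
thor_cfg = edict({
    'momentum': 0.9,
    'weight_decay': 5e-4,
    'loss_scale': 1,
    'frequency': 100,   # second-order matrices refreshed every 100 steps (was 10)
    'batch_size': 12,   # was 8
})
```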
@@ -19,7 +19,6 @@ python run_pretrain.py
import argparse
import os
import numpy
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from src.bert_net_config import bert_net_cfg
@@ -27,10 +26,8 @@ from src.config import cfg
from src.dataset import create_bert_dataset
from src.lr_generator import get_bert_lr, get_bert_damping
from src.model_thor import Model
# from src.thor_for_bert import THOR
from src.thor_for_bert_arg import THOR
from src.utils import LossCallBack, BertLearningRate
import mindspore.common.dtype as mstype
import mindspore.communication.management as D
from mindspore import context
@@ -69,8 +66,8 @@ def run_pretrain():
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id,
save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
device_id=args_opt.device_id, save_graphs=False)
context.set_context(reserve_class_name_in_scope=False)
context.set_context(variable_memory_max_size="30GB")
ckpt_save_dir = args_opt.save_checkpoint_path
@@ -165,15 +162,13 @@ def run_pretrain():
optimizer = THOR(filter(lambda x: x.requires_grad, net_with_loss.get_parameters()), lr, cfg.Thor.momentum,
filter(lambda x: 'matrix_A' in x.name, net_with_loss.get_parameters()),
filter(lambda x: 'matrix_G' in x.name, net_with_loss.get_parameters()),
filter(lambda x: 'A_inv_max' in x.name, net_with_loss.get_parameters()),
filter(lambda x: 'G_inv_max' in x.name, net_with_loss.get_parameters()),
cfg.Thor.weight_decay, cfg.Thor.loss_scale, bert_net_cfg.num_hidden_layers,
bert_net_cfg.batch_size, damping)
else:
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
format(cfg.optimizer))
callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
if args_opt.enable_save_ckpt == "true":
if args_opt.enable_save_ckpt == "true" and rank == 0:
config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
keep_checkpoint_max=args_opt.save_checkpoint_num)
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', directory=ckpt_save_dir, config=config_ck)
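The last change above restricts checkpoint writing to rank 0 in distributed runs. A minimal, self-contained sketch of that pattern with the MindSpore callbacks used in the hunk (the helper name `build_callbacks` is ours, not part of the PR):

```python
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint

def build_callbacks(callbacks, enable_save_ckpt, rank, ckpt_save_dir,
                    save_checkpoint_steps, save_checkpoint_num):
    """Append a ModelCheckpoint only on rank 0, so a single device writes files."""
    if enable_save_ckpt == "true" and rank == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps,
                                     keep_checkpoint_max=save_checkpoint_num)
        callbacks.append(ModelCheckpoint(prefix='checkpoint_bert',
                                         directory=ckpt_save_dir, config=config_ck))
    return callbacks
```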
@@ -37,25 +37,26 @@ do
rm -rf LOG$i
mkdir ./LOG$i
cp *.py ./LOG$i
cp -r src ./LOG$i
cp ../*.py ./LOG$i
cp -r ../src ./LOG$i
cd ./LOG$i || exit
echo "start training for rank $i, device $DEVICE_ID"
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python ../run_pretrain.py \
python run_pretrain.py \
--distribute="true" \
--epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \
--device_num=$RANK_SIZE \
--enable_save_ckpt="true" \
--enable_lossscale="false" \
--do_shuffle="true" \
--do_shuffle="false" \
--enable_data_sink="true" \
--data_sink_steps=1000 \
--load_checkpoint_path="" \
--save_checkpoint_steps=5000 \
--save_checkpoint_path='./' \
--save_checkpoint_steps=1000 \
--save_checkpoint_num=30 \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
cd ../
done
done
@@ -20,27 +20,39 @@ echo "bash run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR"
echo "for example: bash run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json"
echo "=============================================================================================================="
DEVICE_ID=$1
EPOCH_SIZE=$2
DATA_DIR=$3
SCHEMA_DIR=$4
mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_pretrain.py \
--distribute="false" \
--epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \
--enable_save_ckpt="true" \
--enable_lossscale="true" \
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=1 \
--load_checkpoint_path="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
ulimit -u unlimited
export DEVICE_ID=$1
export RANK_SIZE=1
if [ -d "LOG" ];
then
rm -rf ./LOG
fi
mkdir ./LOG
cp ../*.py ./LOG
cp -r ../src ./LOG
cd ./LOG || exit
echo "start training for device $DEVICE_ID"
env > env.log
python run_pretrain.py \
--distribute="false" \
--epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \
--device_num=$RANK_SIZE \
--enable_save_ckpt="true" \
--enable_lossscale="false" \
--do_shuffle="false" \
--enable_data_sink="true" \
--data_sink_steps=1000 \
--load_checkpoint_path="" \
--save_checkpoint_path='./' \
--save_checkpoint_steps=5000 \
--save_checkpoint_num=20 \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
cd ../
@@ -35,6 +35,8 @@ from .thor_layer import Dense_Thor
damping = get_bert_damping()
loss_scale = cfg.Thor.loss_scale
frequency = cfg.Thor.frequency
batch_size = cfg.Thor.batch_size
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
@@ -91,9 +93,9 @@ class GetMaskedLMOutput(nn.Cell):
bias_init='zeros',
damping=damping,
loss_scale=loss_scale,
frequency=1,
frequency=frequency,
activation=config.hidden_act,
batch_size=config.batch_size).to_float(config.compute_type)
batch_size=batch_size).to_float(config.compute_type)
self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type)
self.output_bias = Parameter(
initializer(
| @@ -34,6 +34,7 @@ from .thor_layer import Dense_Thor, Embedding_Thor | |||
| damping = get_bert_damping() | |||
| loss_scale = cfg.Thor.loss_scale | |||
| frequency = cfg.Thor.frequency | |||
| batch_size = cfg.Thor.batch_size | |||
| @@ -200,11 +201,10 @@ class EmbeddingPostprocessor(nn.Cell): | |||
| use_one_hot_embeddings=use_one_hot_embeddings, | |||
| initializer_range=initializer_range, | |||
| name='embedding_table', | |||
| is_expand=False, | |||
| batch_size=batch_size, | |||
| damping=damping, | |||
| loss_scale=loss_scale, | |||
| frequency=1) | |||
| frequency=frequency) | |||
| self.shape_flat = (-1,) | |||
| self.one_hot = P.OneHot() | |||
| self.on_value = Tensor(1.0, mstype.float32) | |||
| @@ -225,11 +225,10 @@ class EmbeddingPostprocessor(nn.Cell): | |||
| use_one_hot_embeddings=use_one_hot_embeddings, | |||
| initializer_range=initializer_range, | |||
| name='full_position_embeddings', | |||
| is_expand=False, | |||
| batch_size=batch_size, | |||
| damping=damping, | |||
| loss_scale=loss_scale, | |||
| frequency=1) | |||
| frequency=frequency) | |||
| self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32)) | |||
| self.layernorm = nn.LayerNorm((embedding_size,)) | |||
| @@ -274,7 +273,7 @@ class BertOutput(nn.Cell): | |||
| bias_init='zeros', | |||
| damping=damping, | |||
| loss_scale=loss_scale, | |||
| frequency=1, | |||
| frequency=frequency, | |||
| activation=None, | |||
| batch_size=batch_size).to_float(compute_type) | |||
| self.dropout = nn.Dropout(1 - dropout_prob) | |||
| @@ -488,7 +487,7 @@ class BertAttention(nn.Cell): | |||
| bias_init='zeros', | |||
| damping=damping, | |||
| loss_scale=loss_scale, | |||
| frequency=1, | |||
| frequency=frequency, | |||
| activation=query_act, | |||
| batch_size=batch_size).to_float(compute_type) | |||
| self.key_layer = Dense_Thor(in_channels=to_tensor_width, | |||
| @@ -498,7 +497,7 @@ class BertAttention(nn.Cell): | |||
| bias_init='zeros', | |||
| damping=damping, | |||
| loss_scale=loss_scale, | |||
| frequency=1, | |||
| frequency=frequency, | |||
| activation=key_act, | |||
| batch_size=batch_size).to_float(compute_type) | |||
| self.value_layer = Dense_Thor(in_channels=to_tensor_width, | |||
| @@ -508,7 +507,7 @@ class BertAttention(nn.Cell): | |||
| bias_init='zeros', | |||
| damping=damping, | |||
| loss_scale=loss_scale, | |||
| frequency=1, | |||
| frequency=frequency, | |||
| activation=value_act, | |||
| batch_size=batch_size).to_float(compute_type) | |||
| self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head) | |||
| @@ -764,7 +763,7 @@ class BertEncoderCell(nn.Cell): | |||
| bias_init='zeros', | |||
| damping=damping, | |||
| loss_scale=loss_scale, | |||
| frequency=1, | |||
| frequency=frequency, | |||
| activation=hidden_act, | |||
| batch_size=batch_size).to_float(compute_type) | |||
| self.output = BertOutput(in_channels=intermediate_size, | |||
| @@ -945,11 +944,10 @@ class BertModel(nn.Cell): | |||
| use_one_hot_embeddings=use_one_hot_embeddings, | |||
| initializer_range=config.initializer_range, | |||
| name='embedding_table', | |||
| is_expand=True, | |||
| batch_size=batch_size, | |||
| damping=damping, | |||
| loss_scale=loss_scale, | |||
| frequency=1) | |||
| frequency=frequency) | |||
| self.bert_embedding_postprocessor = EmbeddingPostprocessor( | |||
| embedding_size=self.embedding_size, | |||
| embedding_shape=output_embedding_shape, | |||
| @@ -991,7 +989,7 @@ class BertModel(nn.Cell): | |||
| bias_init='zeros', | |||
| damping=damping, | |||
| loss_scale=loss_scale, | |||
| frequency=1, | |||
| frequency=frequency, | |||
| activation="tanh", | |||
| batch_size=batch_size).to_float(config.compute_type) | |||
| self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config) | |||
@@ -19,9 +19,6 @@ from easydict import EasyDict as edict
cfg = edict({
'bert_network': 'large',
'loss_scale_value': 65536,
'scale_factor': 2,
'scale_window': 1000,
'optimizer': 'Thor',
'AdamWeightDecay': edict({
'learning_rate': 3e-5,
@@ -49,7 +46,7 @@ cfg = edict({
'momentum': 0.9,
'weight_decay': 5e-4,
'loss_scale': 1,
'frequency': 10,
'batch_size': 8,
'frequency': 100,
'batch_size': 12,
}),
})
@@ -16,7 +16,6 @@
Data operations, will be used in run_pretrain.py
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C
@@ -37,7 +36,7 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
shuffle=de.Shuffle.FILES if do_shuffle == "true" else False,
num_shards=device_num, shard_id=rank, shard_equal_rows=True)
num_shards=device_num, shard_id=rank, shard_equal_rows=False)
ori_dataset_size = ds.get_dataset_size()
print('origin dataset size: ', ori_dataset_size)
type_cast_op = C.TypeCast(mstype.int32)
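With `shard_equal_rows` now False, each device's shard is no longer forced to contain the same number of rows. A minimal sketch of the sharded TFRecord load this function performs (file and schema paths are placeholders, not taken from the PR):

```python
import mindspore.dataset.engine.datasets as de

# Placeholder paths; in create_bert_dataset these come from data_dir and schema_dir.
data_files = ["/path/to/part-00000.tfrecord"]
schema_file = "/path/to/schema.json"

ds = de.TFRecordDataset(data_files, schema_file,
                        columns_list=["input_ids", "input_mask", "segment_ids",
                                      "next_sentence_labels", "masked_lm_positions",
                                      "masked_lm_ids", "masked_lm_weights"],
                        shuffle=de.Shuffle.FILES,        # do_shuffle == "true"
                        num_shards=8, shard_id=0,        # device_num / rank
                        shard_equal_rows=False)
print('origin dataset size:', ds.get_dataset_size())
```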
@@ -80,7 +80,7 @@ def _tensors_cast_datatype(datatype, grad):
return F.cast(grad, datatype)
class DistributedGradReducerThor1(Cell):
class DistributedGradReducerThor(Cell):
"""
A distributed optimizer.
@@ -154,7 +154,7 @@ class DistributedGradReducerThor1(Cell):
"""
def __init__(self, parameters, group, mean=True, degree=None):
super(DistributedGradReducerThor1, self).__init__(auto_prefix=False)
super(DistributedGradReducerThor, self).__init__(auto_prefix=False)
self.hyper_map = C.HyperMap()
self.mul = P.Mul()
if degree is None:
@@ -168,7 +168,7 @@ class DistributedGradReducerThor1(Cell):
_init_optimizer_allreduce(group)
def construct(self, grads):
"""construct of DistributedGradReducerThor"""
# In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
# result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
# and cast back after the operation.
@@ -58,7 +58,7 @@ def get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps,
# bert kfac hyperparam setting
def get_bert_lr():
learning_rate = Tensor(
get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=4e-4, warmup_steps=0, total_steps=30000,
get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=3.1e-3, warmup_steps=0, total_steps=30000,
poly_power=1))
return learning_rate
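This hunk raises the peak THOR learning rate from 4e-4 to 3.1e-3. The body of `get_poly_lr` is not part of this diff; a sketch of the schedule shape we assume it produces (with poly_power=1 the decay is linear):

```python
import numpy as np

# Assumed behaviour of get_poly_lr (not shown in this diff): optional warmup to
# lr_max, then polynomial decay to lr_end over total_steps.
def poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps, poly_power):
    lrs = []
    for step in range(global_step, total_steps):
        if warmup_steps and step < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * step / warmup_steps
        else:
            decay_steps = max(total_steps - warmup_steps, 1)
            frac = min(step - warmup_steps, decay_steps) / decay_steps
            lr = (lr_max - lr_end) * (1.0 - frac) ** poly_power + lr_end
        lrs.append(lr)
    return np.array(lrs).astype(np.float32)

# With the new setting the peak rate is 3.1e-3, decaying linearly over 30k steps.
lr_schedule = poly_lr(0, 0.0, 1e-6, 3.1e-3, 0, 30000, 1)
```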
@@ -46,9 +46,8 @@ def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
class THOR(Optimizer):
"""THOR"""
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0,
loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03, frequency=10,
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, weight_decay=0.0,
loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03,
decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()):
super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale)
if isinstance(momentum, float) and momentum < 0.0:
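Matching this signature change, run_pretrain.py (earlier in this diff) now builds the optimizer without the `A_inv_max`/`G_inv_max` parameter groups and without a `frequency` argument. A condensed sketch of that call, with the parameter list materialized once so the filters do not exhaust a shared generator; `net_with_loss`, `lr`, `damping`, `cfg`, and `bert_net_cfg` are assumed to be set up as in that file:

```python
# Condensed from the run_pretrain.py hunk in this PR.
params = list(net_with_loss.get_parameters())
optimizer = THOR(filter(lambda x: x.requires_grad, params), lr, cfg.Thor.momentum,
                 filter(lambda x: 'matrix_A' in x.name, params),
                 filter(lambda x: 'matrix_G' in x.name, params),
                 cfg.Thor.weight_decay, cfg.Thor.loss_scale,
                 bert_net_cfg.num_hidden_layers, bert_net_cfg.batch_size, damping)
```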
| @@ -60,8 +59,6 @@ class THOR(Optimizer): | |||
| self.opt = P.ApplyMomentum() | |||
| self.matrix_A = ParameterTuple(matrix_A) | |||
| self.matrix_G = ParameterTuple(matrix_G) | |||
| self.A_inv_max = ParameterTuple(A_inv_max) | |||
| self.G_inv_max = ParameterTuple(G_inv_max) | |||
| self.matmul = P.MatMul() | |||
| self.transpose = P.Transpose() | |||
| self.shape = P.Shape() | |||
| @@ -70,16 +67,8 @@ class THOR(Optimizer): | |||
| self.gather = P.GatherV2() | |||
| self.matrix_A_inv = () | |||
| self.matrix_G_inv = () | |||
| self.matrix_max_inv = () | |||
| self.num_hidden_layers = num_hidden_layers | |||
| fc_layer_num = num_hidden_layers * 6 + 5 | |||
| for i in range(fc_layer_num): | |||
| self.matrix_max_inv = self.matrix_max_inv + ( | |||
| Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),) | |||
| self.log = P.Log() | |||
| self.exp = P.Exp() | |||
| self.sqrt = P.Sqrt() | |||
| self.matrix_max_inv = ParameterTuple(self.matrix_max_inv) | |||
| self.assign = P.Assign() | |||
| self.cast = P.Cast() | |||
| self.thor = True | |||
| @@ -90,7 +79,6 @@ class THOR(Optimizer): | |||
| self.inv = P.Inv() | |||
| self.batch_size = batch_size | |||
| self.damping = damping | |||
| self.freq = Tensor(frequency, mstype.int32) | |||
| self.one = Tensor(1, mstype.int32) | |||
| self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False) | |||
| @@ -106,26 +94,20 @@ class THOR(Optimizer): | |||
| g = gradients[em_idx] | |||
| matrix_idx = em_idx | |||
| temp_a_ori = self.matrix_A[matrix_idx] | |||
| temp_a = self.expand(temp_a_ori, 1) | |||
| temp_g = self.matrix_G[matrix_idx] | |||
| G_max = self.G_inv_max[matrix_idx] | |||
| temp_g = self.cast(temp_g, mstype.float32) | |||
| matrix_G_inv_max = self.log(G_max) | |||
| matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) | |||
| matrix_G_inv_max = self.exp(matrix_G_inv_max) | |||
| temp_g = self.mul(temp_g, matrix_G_inv_max) | |||
| g = self.mul(temp_a, g) | |||
| g = self.cast(g, mstype.float16) | |||
| temp_a_ori = F.depend(temp_a_ori, g) | |||
| temp_g = F.depend(temp_g, g) | |||
| temp_a = self.expand(temp_a_ori, 1) | |||
| temp_a = self.cast(temp_a, mstype.float16) | |||
| temp_g = self.cast(temp_g, mstype.float16) | |||
| g = self.cast(g, mstype.float16) | |||
| g = self.mul(temp_a, g) | |||
| g = self.matmul(g, temp_g) | |||
| g = self.cast(g, mstype.float32) | |||
| g = self.mul(g, G_max) | |||
| fake_A = self.assign(self.matrix_A[matrix_idx], temp_a_ori) | |||
| fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) | |||
| fake_max = self.assign(self.matrix_max_inv[matrix_idx], G_max) | |||
| g = F.depend(g, fake_A) | |||
| g = F.depend(g, fake_G) | |||
| g = F.depend(g, fake_max) | |||
| new_grads = new_grads + (g,) | |||
| # process bert_embedding_postprocessor.layernorm | |||
| grad_idx = 3 | |||
@@ -180,32 +162,18 @@ class THOR(Optimizer):
matrix_idx = 6 * i + offset_idx + 3
temp_a = self.matrix_A[matrix_idx]
temp_g = self.matrix_G[matrix_idx]
temp_a = self.cast(temp_a, mstype.float32)
temp_g = self.cast(temp_g, mstype.float32)
matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx])
matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
matrix_A_inv_max = self.exp(matrix_A_inv_max)
temp_a = self.mul(temp_a, matrix_A_inv_max)
matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx])
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
matrix_G_inv_max = self.exp(matrix_G_inv_max)
temp_g = self.mul(temp_g, matrix_G_inv_max)
temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx])
temp_a = F.depend(temp_a, g)
temp_g = F.depend(temp_g, g)
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, temp_max)
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
g = F.depend(g, fake_max)
new_grads = new_grads + (g,)
new_grads = new_grads + (gradients[grad_idx + 1],)
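For readers tracking the math: with the `*_inv_max` bookkeeping removed, the stored factors are applied to the gradients directly. Writing `matrix_A` for the saved inverse input factor A^-1 and `matrix_G` for the saved inverse output factor G^-1, the preconditioned gradient in the hunks above is, up to the float16/float32 casts,

```latex
\hat{g}_{\text{dense}} \;=\; G^{-1}\, g\, A^{-1},
\qquad
\hat{g}_{\text{embedding}} \;=\; \operatorname{diag}(a^{-1})\, g\, G^{-1},
```

where a^-1 is the element-wise inverse of the diagonal input factor kept for the embedding table (the `expand` plus `mul` path in the embedding hunk).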
| @@ -216,32 +184,18 @@ class THOR(Optimizer): | |||
| pooler_bias = gradients[pooler_layer_idx + 1] | |||
| temp_a = self.matrix_A[matrix_idx] | |||
| temp_g = self.matrix_G[matrix_idx] | |||
| temp_a = self.cast(temp_a, mstype.float32) | |||
| temp_g = self.cast(temp_g, mstype.float32) | |||
| matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx]) | |||
| matrix_A_inv_max = self.mul(matrix_A_inv_max, -1) | |||
| matrix_A_inv_max = self.exp(matrix_A_inv_max) | |||
| temp_a = self.mul(temp_a, matrix_A_inv_max) | |||
| matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx]) | |||
| matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) | |||
| matrix_G_inv_max = self.exp(matrix_G_inv_max) | |||
| temp_g = self.mul(temp_g, matrix_G_inv_max) | |||
| temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx]) | |||
| temp_a = F.depend(temp_a, g) | |||
| temp_g = F.depend(temp_g, g) | |||
| temp_a = self.cast(temp_a, mstype.float16) | |||
| temp_g = self.cast(temp_g, mstype.float16) | |||
| g = self.cast(g, mstype.float16) | |||
| g = self.matmul(temp_g, g) | |||
| g = self.matmul(g, temp_a) | |||
| g = self.cast(g, mstype.float32) | |||
| g = self.mul(g, temp_max) | |||
| fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) | |||
| fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) | |||
| fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max) | |||
| g = F.depend(g, fake_A) | |||
| g = F.depend(g, fake_G) | |||
| g = F.depend(g, fake_max) | |||
| new_grads = new_grads + (g, pooler_bias) | |||
| # for cls1 fc layer: mlm | |||
| @@ -251,38 +205,26 @@ class THOR(Optimizer): | |||
| mlm_bias = gradients[mlm_fc_idx + 1] | |||
| temp_a = self.matrix_A[matrix_idx] | |||
| temp_g = self.matrix_G[matrix_idx] | |||
| temp_a = self.cast(temp_a, mstype.float32) | |||
| temp_g = self.cast(temp_g, mstype.float32) | |||
| matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx]) | |||
| matrix_A_inv_max = self.mul(matrix_A_inv_max, -1) | |||
| matrix_A_inv_max = self.exp(matrix_A_inv_max) | |||
| temp_a = self.mul(temp_a, matrix_A_inv_max) | |||
| matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx]) | |||
| matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) | |||
| matrix_G_inv_max = self.exp(matrix_G_inv_max) | |||
| temp_g = self.mul(temp_g, matrix_G_inv_max) | |||
| temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx]) | |||
| temp_a = F.depend(temp_a, g) | |||
| temp_g = F.depend(temp_g, g) | |||
| temp_a = self.cast(temp_a, mstype.float16) | |||
| temp_g = self.cast(temp_g, mstype.float16) | |||
| g = self.cast(g, mstype.float16) | |||
| g = self.matmul(temp_g, g) | |||
| g = self.matmul(g, temp_a) | |||
| g = self.cast(g, mstype.float32) | |||
| g = self.mul(g, temp_max) | |||
| # add bert.cls1.output_bias grad | |||
| fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) | |||
| fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) | |||
| fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max) | |||
| g = F.depend(g, fake_A) | |||
| g = F.depend(g, fake_G) | |||
| g = F.depend(g, fake_max) | |||
| new_grads = new_grads + (gradients[mlm_fc_idx - 1],) | |||
| new_grads = new_grads + (g, mlm_bias) | |||
| # add bert.cls1.layernorm grad | |||
| begin_idx = mlm_fc_idx + 2 | |||
| end_idx = mlm_fc_idx + 4 | |||
| new_grads = new_grads + gradients[begin_idx: end_idx] | |||
| lenth = len(gradients) | |||
| new_grads = new_grads + gradients[lenth - 2: lenth] | |||
| gradients = new_grads | |||
| @@ -293,15 +235,16 @@ class THOR(Optimizer): | |||
| g = gradients[em_idx] | |||
| matrix_idx = em_idx | |||
| temp_a = self.matrix_A[matrix_idx] | |||
| temp_a = self.expand(temp_a, 1) | |||
| temp_g = self.matrix_G[matrix_idx] | |||
| matrix_max = self.matrix_max_inv[matrix_idx] | |||
| g = self.mul(temp_a, g) | |||
| temp_a = F.depend(temp_a, g) | |||
| temp_g = F.depend(temp_g, g) | |||
| temp_a = self.expand(temp_a, 1) | |||
| temp_a = self.cast(temp_a, mstype.float16) | |||
| temp_g = self.cast(temp_g, mstype.float16) | |||
| g = self.cast(g, mstype.float16) | |||
| g = self.mul(temp_a, g) | |||
| g = self.matmul(g, temp_g) | |||
| g = self.cast(g, mstype.float32) | |||
| g = self.mul(g, matrix_max) | |||
| new_grads = new_grads + (g,) | |||
| # process bert_embedding_postprocessor.layernorm | |||
| grad_idx = 3 | |||
| @@ -356,15 +299,14 @@ class THOR(Optimizer): | |||
| matrix_idx = 6 * i + offset_idx + 3 | |||
| temp_a = self.matrix_A[matrix_idx] | |||
| temp_g = self.matrix_G[matrix_idx] | |||
| matrix_max = self.matrix_max_inv[matrix_idx] | |||
| temp_a = F.depend(temp_a, g) | |||
| temp_g = F.depend(temp_g, g) | |||
| temp_a = self.cast(temp_a, mstype.float16) | |||
| temp_g = self.cast(temp_g, mstype.float16) | |||
| g = self.cast(g, mstype.float16) | |||
| g = self.matmul(temp_g, g) | |||
| g = self.matmul(g, temp_a) | |||
| g = self.cast(g, mstype.float32) | |||
| g = self.mul(g, matrix_max) | |||
| new_grads = new_grads + (g,) | |||
| new_grads = new_grads + (gradients[grad_idx + 1],) | |||
| @@ -375,15 +317,14 @@ class THOR(Optimizer): | |||
| pooler_bias = gradients[pooler_layer_idx + 1] | |||
| temp_a = self.matrix_A[matrix_idx] | |||
| temp_g = self.matrix_G[matrix_idx] | |||
| matrix_max = self.matrix_max_inv[matrix_idx] | |||
| temp_a = F.depend(temp_a, g) | |||
| temp_g = F.depend(temp_g, g) | |||
| temp_a = self.cast(temp_a, mstype.float16) | |||
| temp_g = self.cast(temp_g, mstype.float16) | |||
| g = self.cast(g, mstype.float16) | |||
| g = self.matmul(temp_g, g) | |||
| g = self.matmul(g, temp_a) | |||
| g = self.cast(g, mstype.float32) | |||
| g = self.mul(g, matrix_max) | |||
| new_grads = new_grads + (g, pooler_bias) | |||
| # for cls1 fc layer: mlm | |||
| @@ -393,15 +334,14 @@ class THOR(Optimizer): | |||
| mlm_bias = gradients[mlm_fc_idx + 1] | |||
| temp_a = self.matrix_A[matrix_idx] | |||
| temp_g = self.matrix_G[matrix_idx] | |||
| matrix_max = self.matrix_max_inv[matrix_idx] | |||
| temp_a = F.depend(temp_a, g) | |||
| temp_g = F.depend(temp_g, g) | |||
| temp_a = self.cast(temp_a, mstype.float16) | |||
| temp_g = self.cast(temp_g, mstype.float16) | |||
| g = self.cast(g, mstype.float16) | |||
| g = self.matmul(temp_g, g) | |||
| g = self.matmul(g, temp_a) | |||
| g = self.cast(g, mstype.float32) | |||
| g = self.mul(g, matrix_max) | |||
| # add bert.cls1.output_bias grad | |||
| new_grads = new_grads + (gradients[mlm_fc_idx - 1],) | |||
| new_grads = new_grads + (g, mlm_bias) | |||
| @@ -409,6 +349,7 @@ class THOR(Optimizer): | |||
| begin_idx = mlm_fc_idx + 2 | |||
| end_idx = mlm_fc_idx + 4 | |||
| new_grads = new_grads + gradients[begin_idx: end_idx] | |||
| lenth = len(gradients) | |||
| new_grads = new_grads + gradients[lenth - 2: lenth] | |||
| gradients = new_grads | |||
| @@ -21,7 +21,7 @@ from mindspore.common.tensor import Tensor | |||
| from mindspore.nn.optim.optimizer import Optimizer | |||
| from mindspore.ops import functional as F, composite as C, operations as P | |||
| from mindspore.parallel._utils import _get_device_num, _get_mirror_mean | |||
| from .grad_reducer_thor1 import DistributedGradReducerThor1 | |||
| from .grad_reducer_thor import DistributedGradReducerThor | |||
| momentum_opt = C.MultitypeFuncGraph("momentum_opt") | |||
| @@ -48,9 +48,8 @@ def _tensor_apply_decay(weight_decay, if_apply, weight, gradient): | |||
| class THOR(Optimizer): | |||
| """THOR""" | |||
| def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0, | |||
| loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03, frequency=10, | |||
| def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, weight_decay=0.0, | |||
| loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03, | |||
| decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()): | |||
| super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale) | |||
| if isinstance(momentum, float) and momentum < 0.0: | |||
| @@ -62,8 +61,6 @@ class THOR(Optimizer): | |||
| self.opt = P.ApplyMomentum() | |||
| self.matrix_A = ParameterTuple(matrix_A) | |||
| self.matrix_G = ParameterTuple(matrix_G) | |||
| self.A_inv_max = ParameterTuple(A_inv_max) | |||
| self.G_inv_max = ParameterTuple(G_inv_max) | |||
| self.matmul = P.MatMul() | |||
| self.transpose = P.Transpose() | |||
| self.shape = P.Shape() | |||
| @@ -72,16 +69,8 @@ class THOR(Optimizer): | |||
| self.gather = P.GatherV2() | |||
| self.matrix_A_inv = () | |||
| self.matrix_G_inv = () | |||
| self.matrix_max_inv = () | |||
| self.num_hidden_layers = num_hidden_layers | |||
| fc_layer_num = num_hidden_layers * 6 + 5 | |||
| for i in range(fc_layer_num): | |||
| self.matrix_max_inv = self.matrix_max_inv + ( | |||
| Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),) | |||
| self.log = P.Log() | |||
| self.exp = P.Exp() | |||
| self.sqrt = P.Sqrt() | |||
| self.matrix_max_inv = ParameterTuple(self.matrix_max_inv) | |||
| self.assign = P.Assign() | |||
| self.cast = P.Cast() | |||
| self.thor = True | |||
| @@ -92,12 +81,11 @@ class THOR(Optimizer): | |||
| self.inv = P.Inv() | |||
| self.batch_size = batch_size | |||
| self.damping = damping | |||
| self.freq = Tensor(frequency, mstype.int32) | |||
| self.one = Tensor(1, mstype.int32) | |||
| self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False) | |||
| mean = _get_mirror_mean() | |||
| degree = _get_device_num() | |||
| self.grad_reducer_g = DistributedGradReducerThor1(self.parameters, 3, mean, degree) | |||
| self.grad_reducer_g = DistributedGradReducerThor(self.parameters, 3, mean, degree) | |||
| def construct(self, gradients): | |||
| """construct of THOR""" | |||
| @@ -111,26 +99,20 @@ class THOR(Optimizer): | |||
| g = gradients[em_idx] | |||
| matrix_idx = em_idx | |||
| temp_a_ori = self.matrix_A[matrix_idx] | |||
| temp_a = self.expand(temp_a_ori, 1) | |||
| temp_g = self.matrix_G[matrix_idx] | |||
| G_max = self.G_inv_max[matrix_idx] | |||
| temp_g = self.cast(temp_g, mstype.float32) | |||
| matrix_G_inv_max = self.log(G_max) | |||
| matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) | |||
| matrix_G_inv_max = self.exp(matrix_G_inv_max) | |||
| temp_g = self.mul(temp_g, matrix_G_inv_max) | |||
| g = self.mul(temp_a, g) | |||
| g = self.cast(g, mstype.float16) | |||
| temp_a_ori = F.depend(temp_a_ori, g) | |||
| temp_g = F.depend(temp_g, g) | |||
| temp_a = self.expand(temp_a_ori, 1) | |||
| temp_a = self.cast(temp_a, mstype.float16) | |||
| temp_g = self.cast(temp_g, mstype.float16) | |||
| g = self.cast(g, mstype.float16) | |||
| g = self.mul(temp_a, g) | |||
| g = self.matmul(g, temp_g) | |||
| g = self.cast(g, mstype.float32) | |||
| g = self.mul(g, G_max) | |||
| fake_A = self.assign(self.matrix_A[matrix_idx], temp_a_ori) | |||
| fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) | |||
| fake_max = self.assign(self.matrix_max_inv[matrix_idx], G_max) | |||
| g = F.depend(g, fake_A) | |||
| g = F.depend(g, fake_G) | |||
| g = F.depend(g, fake_max) | |||
| new_grads = new_grads + (g,) | |||
| # process bert_embedding_postprocessor.layernorm | |||
| grad_idx = 3 | |||
| @@ -185,32 +167,18 @@ class THOR(Optimizer): | |||
| matrix_idx = 6 * i + offset_idx + 3 | |||
| temp_a = self.matrix_A[matrix_idx] | |||
| temp_g = self.matrix_G[matrix_idx] | |||
| temp_a = self.cast(temp_a, mstype.float32) | |||
| temp_g = self.cast(temp_g, mstype.float32) | |||
| matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx]) | |||
| matrix_A_inv_max = self.mul(matrix_A_inv_max, -1) | |||
| matrix_A_inv_max = self.exp(matrix_A_inv_max) | |||
| temp_a = self.mul(temp_a, matrix_A_inv_max) | |||
| matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx]) | |||
| matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) | |||
| matrix_G_inv_max = self.exp(matrix_G_inv_max) | |||
| temp_g = self.mul(temp_g, matrix_G_inv_max) | |||
| temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx]) | |||
| temp_a = F.depend(temp_a, g) | |||
| temp_g = F.depend(temp_g, g) | |||
| temp_a = self.cast(temp_a, mstype.float16) | |||
| temp_g = self.cast(temp_g, mstype.float16) | |||
| g = self.cast(g, mstype.float16) | |||
| g = self.matmul(temp_g, g) | |||
| g = self.matmul(g, temp_a) | |||
| g = self.cast(g, mstype.float32) | |||
| g = self.mul(g, temp_max) | |||
| fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) | |||
| fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) | |||
| fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max) | |||
| g = F.depend(g, fake_A) | |||
| g = F.depend(g, fake_G) | |||
| g = F.depend(g, fake_max) | |||
| new_grads = new_grads + (g,) | |||
| new_grads = new_grads + (gradients[grad_idx + 1],) | |||
| @@ -221,32 +189,18 @@ class THOR(Optimizer): | |||
| pooler_bias = gradients[pooler_layer_idx + 1] | |||
| temp_a = self.matrix_A[matrix_idx] | |||
| temp_g = self.matrix_G[matrix_idx] | |||
| temp_a = self.cast(temp_a, mstype.float32) | |||
| temp_g = self.cast(temp_g, mstype.float32) | |||
| matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx]) | |||
| matrix_A_inv_max = self.mul(matrix_A_inv_max, -1) | |||
| matrix_A_inv_max = self.exp(matrix_A_inv_max) | |||
| temp_a = self.mul(temp_a, matrix_A_inv_max) | |||
| matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx]) | |||
| matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) | |||
| matrix_G_inv_max = self.exp(matrix_G_inv_max) | |||
| temp_g = self.mul(temp_g, matrix_G_inv_max) | |||
| temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx]) | |||
| temp_a = F.depend(temp_a, g) | |||
| temp_g = F.depend(temp_g, g) | |||
| temp_a = self.cast(temp_a, mstype.float16) | |||
| temp_g = self.cast(temp_g, mstype.float16) | |||
| g = self.cast(g, mstype.float16) | |||
| g = self.matmul(temp_g, g) | |||
| g = self.matmul(g, temp_a) | |||
| g = self.cast(g, mstype.float32) | |||
| g = self.mul(g, temp_max) | |||
| fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) | |||
| fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) | |||
| fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max) | |||
| g = F.depend(g, fake_A) | |||
| g = F.depend(g, fake_G) | |||
| g = F.depend(g, fake_max) | |||
| new_grads = new_grads + (g, pooler_bias) | |||
| # for cls1 fc layer: mlm | |||
| @@ -256,38 +210,26 @@ class THOR(Optimizer): | |||
| mlm_bias = gradients[mlm_fc_idx + 1] | |||
| temp_a = self.matrix_A[matrix_idx] | |||
| temp_g = self.matrix_G[matrix_idx] | |||
| temp_a = self.cast(temp_a, mstype.float32) | |||
| temp_g = self.cast(temp_g, mstype.float32) | |||
| matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx]) | |||
| matrix_A_inv_max = self.mul(matrix_A_inv_max, -1) | |||
| matrix_A_inv_max = self.exp(matrix_A_inv_max) | |||
| temp_a = self.mul(temp_a, matrix_A_inv_max) | |||
| matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx]) | |||
| matrix_G_inv_max = self.mul(matrix_G_inv_max, -1) | |||
| matrix_G_inv_max = self.exp(matrix_G_inv_max) | |||
| temp_g = self.mul(temp_g, matrix_G_inv_max) | |||
| temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx]) | |||
| temp_a = F.depend(temp_a, g) | |||
| temp_g = F.depend(temp_g, g) | |||
| temp_a = self.cast(temp_a, mstype.float16) | |||
| temp_g = self.cast(temp_g, mstype.float16) | |||
| g = self.cast(g, mstype.float16) | |||
| g = self.matmul(temp_g, g) | |||
| g = self.matmul(g, temp_a) | |||
| g = self.cast(g, mstype.float32) | |||
| g = self.mul(g, temp_max) | |||
| # add bert.cls1.output_bias grad | |||
| fake_A = self.assign(self.matrix_A[matrix_idx], temp_a) | |||
| fake_G = self.assign(self.matrix_G[matrix_idx], temp_g) | |||
| fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max) | |||
| g = F.depend(g, fake_A) | |||
| g = F.depend(g, fake_G) | |||
| g = F.depend(g, fake_max) | |||
| new_grads = new_grads + (gradients[mlm_fc_idx - 1],) | |||
| new_grads = new_grads + (g, mlm_bias) | |||
| # add bert.cls1.layernorm grad | |||
| begin_idx = mlm_fc_idx + 2 | |||
| end_idx = mlm_fc_idx + 4 | |||
| new_grads = new_grads + gradients[begin_idx: end_idx] | |||
| lenth = len(gradients) | |||
| new_grads = new_grads + gradients[lenth - 2: lenth] | |||
| gradients = new_grads | |||
| @@ -299,15 +241,16 @@ class THOR(Optimizer): | |||
| g = gradients[em_idx] | |||
| matrix_idx = em_idx | |||
| temp_a = self.matrix_A[matrix_idx] | |||
| temp_a = self.expand(temp_a, 1) | |||
| temp_g = self.matrix_G[matrix_idx] | |||
| matrix_max = self.matrix_max_inv[matrix_idx] | |||
| g = self.mul(temp_a, g) | |||
| temp_a = F.depend(temp_a, g) | |||
| temp_g = F.depend(temp_g, g) | |||
| temp_a = self.expand(temp_a, 1) | |||
| temp_a = self.cast(temp_a, mstype.float16) | |||
| temp_g = self.cast(temp_g, mstype.float16) | |||
| g = self.cast(g, mstype.float16) | |||
| g = self.mul(temp_a, g) | |||
| g = self.matmul(g, temp_g) | |||
| g = self.cast(g, mstype.float32) | |||
| g = self.mul(g, matrix_max) | |||
| new_grads = new_grads + (g,) | |||
| # process bert_embedding_postprocessor.layernorm | |||
| grad_idx = 3 | |||
| @@ -362,15 +305,14 @@ class THOR(Optimizer): | |||
| matrix_idx = 6 * i + offset_idx + 3 | |||
| temp_a = self.matrix_A[matrix_idx] | |||
| temp_g = self.matrix_G[matrix_idx] | |||
| matrix_max = self.matrix_max_inv[matrix_idx] | |||
| temp_a = F.depend(temp_a, g) | |||
| temp_g = F.depend(temp_g, g) | |||
| temp_a = self.cast(temp_a, mstype.float16) | |||
| temp_g = self.cast(temp_g, mstype.float16) | |||
| g = self.cast(g, mstype.float16) | |||
| g = self.matmul(temp_g, g) | |||
| g = self.matmul(g, temp_a) | |||
| g = self.cast(g, mstype.float32) | |||
| g = self.mul(g, matrix_max) | |||
| new_grads = new_grads + (g,) | |||
| new_grads = new_grads + (gradients[grad_idx + 1],) | |||
| @@ -381,15 +323,14 @@ class THOR(Optimizer): | |||
| pooler_bias = gradients[pooler_layer_idx + 1] | |||
| temp_a = self.matrix_A[matrix_idx] | |||
| temp_g = self.matrix_G[matrix_idx] | |||
| matrix_max = self.matrix_max_inv[matrix_idx] | |||
| temp_a = F.depend(temp_a, g) | |||
| temp_g = F.depend(temp_g, g) | |||
| temp_a = self.cast(temp_a, mstype.float16) | |||
| temp_g = self.cast(temp_g, mstype.float16) | |||
| g = self.cast(g, mstype.float16) | |||
| g = self.matmul(temp_g, g) | |||
| g = self.matmul(g, temp_a) | |||
| g = self.cast(g, mstype.float32) | |||
| g = self.mul(g, matrix_max) | |||
| new_grads = new_grads + (g, pooler_bias) | |||
| # for cls1 fc layer: mlm | |||
| @@ -399,15 +340,14 @@ class THOR(Optimizer): | |||
| mlm_bias = gradients[mlm_fc_idx + 1] | |||
| temp_a = self.matrix_A[matrix_idx] | |||
| temp_g = self.matrix_G[matrix_idx] | |||
| matrix_max = self.matrix_max_inv[matrix_idx] | |||
| temp_a = F.depend(temp_a, g) | |||
| temp_g = F.depend(temp_g, g) | |||
| temp_a = self.cast(temp_a, mstype.float16) | |||
| temp_g = self.cast(temp_g, mstype.float16) | |||
| g = self.cast(g, mstype.float16) | |||
| g = self.matmul(temp_g, g) | |||
| g = self.matmul(g, temp_a) | |||
| g = self.cast(g, mstype.float32) | |||
| g = self.mul(g, matrix_max) | |||
| # add bert.cls1.output_bias grad | |||
| new_grads = new_grads + (gradients[mlm_fc_idx - 1],) | |||
| new_grads = new_grads + (g, mlm_bias) | |||
| @@ -415,6 +355,7 @@ class THOR(Optimizer): | |||
| begin_idx = mlm_fc_idx + 2 | |||
| end_idx = mlm_fc_idx + 4 | |||
| new_grads = new_grads + gradients[begin_idx: end_idx] | |||
| lenth = len(gradients) | |||
| new_grads = new_grads + gradients[lenth - 2: lenth] | |||
| gradients = new_grads | |||
| @@ -14,7 +14,6 @@ | |||
| # ============================================================================ | |||
| """thor_layer""" | |||
| import numpy as np | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore._checkparam import check_bool, check_int_positive | |||
| from mindspore.common.initializer import TruncatedNormal, initializer | |||
| @@ -24,7 +23,6 @@ from mindspore.nn.cell import Cell | |||
| from mindspore.nn.layer.activation import get_activation | |||
| from mindspore.ops import operations as P | |||
| class Embedding_Thor(Cell): | |||
| """ | |||
| A embeddings lookup table with a fixed dictionary and size. | |||
| @@ -37,7 +35,6 @@ class Embedding_Thor(Cell): | |||
| use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. | |||
| initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. | |||
| """ | |||
| def __init__(self, | |||
| vocab_size, | |||
| embedding_size, | |||
| @@ -45,11 +42,10 @@ class Embedding_Thor(Cell): | |||
| use_one_hot_embeddings=False, | |||
| initializer_range=0.02, | |||
| name='embedding_table', | |||
| is_expand=False, | |||
| batch_size=12, | |||
| damping=0.03, | |||
| loss_scale=1, | |||
| frequency=10, | |||
| frequency=100, | |||
| ): | |||
| super(Embedding_Thor, self).__init__() | |||
| self.vocab_size = vocab_size | |||
| @@ -59,7 +55,6 @@ class Embedding_Thor(Cell): | |||
| [vocab_size, embedding_size]), | |||
| name=name) | |||
| self.thor = True | |||
| self.is_expand = is_expand | |||
| self.expand = P.ExpandDims() | |||
| self.shape_flat = (-1,) | |||
| self.gather = P.GatherV2() | |||
| @@ -71,13 +66,11 @@ class Embedding_Thor(Cell): | |||
| self.em_shape = tuple(embedding_shape) | |||
| self.shape = P.Shape() | |||
| self.loss_scale = Tensor(1 / loss_scale, mstype.float16) | |||
| self.matrix_A_inv = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)), name='matrix_A_inv', | |||
| requires_grad=False) | |||
| self.matrix_A_inv = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float16)), | |||
| name='matrix_A_inv', requires_grad=False) | |||
| self.matrix_G_inv = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float16)), | |||
| name="matrix_G_inv", requires_grad=False) | |||
| self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False) | |||
| self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False) | |||
| self.fused_abs_max = P.CusFusedAbsMax1() | |||
| self.fake_G = Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float16)) | |||
| self.dampingA = Tensor(np.ones([vocab_size]).astype(np.float32)) | |||
| self.dampingG = Tensor(np.identity(embedding_size), mstype.float32) | |||
| @@ -117,9 +110,6 @@ class Embedding_Thor(Cell): | |||
| matrix_G = matrix_G + damping * dampingG | |||
| matrix_G_inv = self.cholesky(matrix_G) | |||
| matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv) | |||
| matrix_G_inv_max = self.fused_abs_max(matrix_G_inv) | |||
| matrix_G_inv_max = self.fused_abs_max(matrix_G_inv_max) | |||
| self.G_inv_max = matrix_G_inv_max | |||
| matrix_G_inv = self.matrix_combine(matrix_G_inv) | |||
| matrix_G_inv = self.cast(matrix_G_inv, mstype.float16) | |||
| self.matrix_G_inv = matrix_G_inv | |||
| @@ -127,8 +117,6 @@ class Embedding_Thor(Cell): | |||
| def construct(self, input_ids): | |||
| """construct of Embedding_Thor""" | |||
| if self.is_expand: | |||
| input_ids = self.expand(input_ids, -1) | |||
| flat_ids = self.reshape(input_ids, self.shape_flat) | |||
| if self.use_one_hot_embeddings: | |||
| one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) | |||
| @@ -146,6 +134,7 @@ class Embedding_Thor(Cell): | |||
| dampingA = self.cast(self.dampingA, mstype.float32) | |||
| matrix_A = matrix_A + damping * dampingA | |||
| matrix_A_inv = self.inv(matrix_A) | |||
| matrix_A_inv = self.cast(matrix_A_inv, mstype.float16) | |||
| self.matrix_A_inv = matrix_A_inv | |||
| self.matrix_G_inv = self.fake_G | |||
| output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) | |||
| @@ -156,11 +145,9 @@ class Embedding_Thor(Cell): | |||
| output = self.reshape(output_for_reshape, self.em_shape) | |||
| return output, self.embedding_table | |||
| class Dense_Thor(Cell): | |||
| """Dense_Thor""" | |||
| # @cell_attr_register(attrs=['has_bias', 'activation', 'in_channels', 'out_channels']) | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| @@ -168,7 +155,7 @@ class Dense_Thor(Cell): | |||
| bias_init='zeros', | |||
| damping=0.03, | |||
| loss_scale=1, | |||
| frequency=10, | |||
| frequency=100, | |||
| has_bias=False, | |||
| activation=None, | |||
| batch_size=12): | |||
| @@ -200,9 +187,6 @@ class Dense_Thor(Cell): | |||
| name='matrix_A_inv', requires_grad=False) | |||
| self.matrix_G_inv = Parameter(Tensor(np.zeros([out_channels, out_channels]).astype(np.float16)), | |||
| name="matrix_G_inv", requires_grad=False) | |||
| self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False) | |||
| self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False) | |||
| self.fused_abs_max = P.CusFusedAbsMax1() | |||
| self.fake_G = Tensor(np.zeros([out_channels, out_channels]).astype(np.float16)) | |||
| self.matmul = P.MatMul(transpose_b=True) | |||
| @@ -250,9 +234,6 @@ class Dense_Thor(Cell): | |||
| matrix_G = matrix_G + damping * dampingG | |||
| matrix_G_inv = self.cholesky(matrix_G) | |||
| matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv) | |||
| matrix_G_inv_max = self.fused_abs_max(matrix_G_inv) | |||
| matrix_G_inv_max = self.fused_abs_max(matrix_G_inv_max) | |||
| self.G_inv_max = matrix_G_inv_max | |||
| matrix_G_inv = self.matrix_combine(matrix_G_inv) | |||
| matrix_G_inv = self.cast(matrix_G_inv, mstype.float16) | |||
| self.matrix_G_inv = matrix_G_inv | |||
@@ -265,7 +246,6 @@ class Dense_Thor(Cell):
shape = self.shape(x)
normalizer = self.cast(shape[0], mstype.float32)
matrix_A = self.mul(inputs, 1.0 / normalizer)
damping_step = self.gather(self.damping, self.cov_step, self.axis)
damping_step = self.cast(damping_step, mstype.float32)
damping = self.sqrt(damping_step)
@@ -273,9 +253,6 @@ class Dense_Thor(Cell):
matrix_A = matrix_A + damping * dampingA
matrix_A_inv = self.cholesky(matrix_A)
matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv)
matrix_A_inv_max = self.fused_abs_max(matrix_A_inv)
matrix_A_inv_max = self.fused_abs_max(matrix_A_inv_max)
self.A_inv_max = matrix_A_inv_max
matrix_A_inv = self.matrix_combine(matrix_A_inv)
matrix_A_inv = self.cast(matrix_A_inv, mstype.float16)
self.matrix_A_inv = matrix_A_inv
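The `save_matrix_A` path above keeps the damped, Cholesky-based inversion and simply drops the fused-abs-max tracking. Assuming, as in standard K-FAC/THOR, that `inputs` holds the batch second-moment of the layer activations, the quantity written into `matrix_A_inv` is (up to the float16 cast):

```latex
\hat{A} \;=\; \tfrac{1}{N}\, X^{\top} X \;+\; \lambda I,
\qquad
\text{matrix\_A\_inv} \;\approx\; \hat{A}^{-1},
```

with N the batch dimension from `self.shape(x)`, the damping term taken from `damping_step`, and the inverse formed by the custom `self.cholesky`, `self.vector_matmul`, and `self.matrix_combine` ops.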