
modify Transformer model

tags/v0.5.0-beta
yuchaojie 5 years ago
parent commit a9b7861a00
3 changed files with 39 additions and 51 deletions
  1. model_zoo/Transformer/eval.py (+1, -2)
  2. model_zoo/Transformer/src/transformer_model.py (+29, -28)
  3. model_zoo/Transformer/train.py (+9, -21)

model_zoo/Transformer/eval.py (+1, -2)

@@ -78,9 +78,8 @@ def load_weights(model_path):


     weights = {}
     for msname in ms_ckpt:
-        infer_name = msname.replace("transformer.transformer.", "")
+        infer_name = msname
         if "tfm_decoder" in msname:
-            infer_name = infer_name.replace(".layers.", ".layer")
             infer_name = "tfm_decoder.decoder." + infer_name
         if is_npz:
             weights[infer_name] = ms_ckpt[msname]
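
The renaming in load_weights() is now simpler: checkpoint keys are taken as-is, and only keys containing "tfm_decoder" get the "tfm_decoder.decoder." prefix; the prefix stripping and ".layers." rewriting are gone. A minimal standalone sketch of the new mapping follows; the key names in the examples are hypothetical, real checkpoints define their own.

    # Standalone sketch mirroring the updated renaming logic in load_weights().
    def map_ckpt_name(msname):
        infer_name = msname  # no "transformer.transformer." stripping, no ".layers." -> ".layer"
        if "tfm_decoder" in msname:
            infer_name = "tfm_decoder.decoder." + infer_name
        return infer_name

    print(map_ckpt_name("tfm_decoder.layer0.attention.query_layer.weight"))
    # -> tfm_decoder.decoder.tfm_decoder.layer0.attention.query_layer.weight
    print(map_ckpt_name("tfm_embedding_lookup.embedding_table"))
    # -> tfm_embedding_lookup.embedding_table (unchanged)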


model_zoo/Transformer/src/transformer_model.py (+29, -28)

@@ -20,11 +20,11 @@ import numpy as np
 import mindspore.common.dtype as mstype
 import mindspore.nn as nn
 import mindspore.ops.functional as F
-from mindspore.common.initializer import TruncatedNormal, initializer
 from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from .beam_search import BeamSearchDecoder, TileBeam
+from .weight_init import normal_weight, weight_variable


 class TransformerConfig:
     """
@@ -118,9 +118,7 @@ class EmbeddingLookup(nn.Cell):
         self.vocab_size = vocab_size
         self.embedding_size = embedding_size
         self.use_one_hot_embeddings = use_one_hot_embeddings
-        self.embedding_table = Parameter(initializer
-                                         (TruncatedNormal(initializer_range),
-                                          [vocab_size, embedding_size]),
+        self.embedding_table = Parameter(normal_weight([vocab_size, embedding_size], embedding_size),
                                          name='embedding_table')
         self.expand = P.ExpandDims()
         self.shape_flat = (-1,)
@@ -138,8 +136,7 @@ class EmbeddingLookup(nn.Cell):
         flat_ids = self.reshape(input_ids, self.shape_flat)
         if self.use_one_hot_embeddings:
             one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
-            output_for_reshape = self.array_mul(
-                one_hot_ids, self.embedding_table)
+            output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
         else:
             output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)


@@ -329,22 +326,22 @@ class MultiheadAttention(nn.Cell):
                                     units,
                                     activation=query_act,
                                     has_bias=False,
-                                    weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
+                                    weight_init=weight_variable([units, from_tensor_width])).to_float(compute_type)
         self.key_layer = nn.Dense(to_tensor_width,
                                   units,
                                   activation=key_act,
                                   has_bias=False,
-                                  weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
+                                  weight_init=weight_variable([units, to_tensor_width])).to_float(compute_type)
         self.value_layer = nn.Dense(to_tensor_width,
                                     units,
                                     activation=value_act,
                                     has_bias=False,
-                                    weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
+                                    weight_init=weight_variable([units, to_tensor_width])).to_float(compute_type)
         self.out_layer = nn.Dense(units,
                                   out_tensor_width,
                                   activation=out_act,
                                   has_bias=False,
-                                  weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
+                                  weight_init=weight_variable([out_tensor_width, units])).to_float(compute_type)

         self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
         self.shape_to = (batch_size, to_seq_length, num_attention_heads, size_per_head)
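
The embedding table and every nn.Dense in the attention and feed-forward blocks now take explicit initial weight tensors from src/weight_init.py instead of a TruncatedNormal(initializer_range) initializer. That file is not part of this diff, so the sketch below only illustrates the expected call signatures, assuming a fan-scaled uniform init for weight_variable and a num_units ** -0.5 scaled normal init for normal_weight; the shapes passed in the diff follow nn.Dense's [out_channels, in_channels] weight layout.

    # Plausible sketch of the helpers imported from src/weight_init.py; the
    # actual implementation may differ, only the signatures come from this diff.
    import math
    import numpy as np
    import mindspore.common.dtype as mstype
    from mindspore.common.tensor import Tensor

    def weight_variable(shape):
        # Assumed: uniform init bounded by sqrt(6 / (fan_in + fan_out)).
        limit = math.sqrt(6.0 / (shape[0] + shape[1]))
        values = np.random.uniform(-limit, limit, shape).astype(np.float32)
        return Tensor(values, dtype=mstype.float32)

    def normal_weight(shape, num_units):
        # Assumed: zero-mean normal scaled by num_units ** -0.5, as in the
        # original Transformer setup.
        values = np.random.normal(0.0, num_units ** -0.5, shape).astype(np.float32)
        return Tensor(values, dtype=mstype.float32)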
@@ -518,10 +515,10 @@ class FeedForward(nn.Cell):
         self.conv1 = nn.Dense(in_channels,
                               hidden_size,
                               activation=hidden_act,
-                              weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
+                              weight_init=weight_variable([hidden_size, in_channels])).to_float(compute_type)
         self.conv2 = nn.Dense(hidden_size,
                               out_channels,
-                              weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
+                              weight_init=weight_variable([out_channels, hidden_size])).to_float(compute_type)

         self.preprocess = LayerPreprocess(in_channels=in_channels)
         self.postprocess = LayerPostprocess(dropout_prob=hidden_dropout_prob)
@@ -1108,7 +1105,13 @@ class TransformerModel(nn.Cell):
             embedding_size=self.embedding_size,
             use_one_hot_embeddings=use_one_hot_embeddings,
             initializer_range=config.initializer_range)
-        self.tfm_embedding_postprocessor = EmbeddingPostprocessor(
+        self.tfm_embedding_postprocessor_for_encoder = EmbeddingPostprocessor(
+            embedding_size=self.embedding_size,
+            use_one_hot_embeddings=use_one_hot_embeddings,
+            initializer_range=0.02,
+            max_position_embeddings=config.max_position_embeddings,
+            dropout_prob=config.hidden_dropout_prob)
+        self.tfm_embedding_postprocessor_for_decoder = EmbeddingPostprocessor(
             embedding_size=self.embedding_size,
             use_one_hot_embeddings=use_one_hot_embeddings,
             initializer_range=0.02,
@@ -1171,7 +1174,7 @@ class TransformerModel(nn.Cell):
             hidden_act=config.hidden_act,
             compute_type=config.compute_type,
             embedding_lookup=self.tfm_embedding_lookup,
-            embedding_processor=self.tfm_embedding_postprocessor,
+            embedding_processor=self.tfm_embedding_postprocessor_for_decoder,
             projection=self.projection)
         self.tfm_decoder = BeamSearchDecoder(
             batch_size=config.batch_size,
@@ -1195,15 +1198,14 @@ class TransformerModel(nn.Cell):
             ones = np.ones(shape=(self.seq_length, self.seq_length))
             self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32)
         else:
-            self.tile_beam = TileBeam(
-                beam_width=config.beam_width)
+            self.tile_beam = TileBeam(beam_width=config.beam_width)
             ones = np.ones(shape=(config.batch_size, config.max_decode_length))
             self.encdec_mask = Tensor(ones, dtype=mstype.float32)

     def construct(self, source_ids, source_mask, target_ids=None, target_mask=None):
         # process source sentence
         src_word_embeddings, embedding_tables = self.tfm_embedding_lookup(source_ids)
-        src_embedding_output = self.tfm_embedding_postprocessor(src_word_embeddings)
+        src_embedding_output = self.tfm_embedding_postprocessor_for_encoder(src_word_embeddings)
         # attention mask [batch_size, seq_length, seq_length]
         enc_attention_mask = self._create_attention_mask_from_input_mask(source_mask)
         # transformer encoder
@@ -1213,7 +1215,7 @@ class TransformerModel(nn.Cell):
         if self.is_training:
             # process target sentence
             tgt_word_embeddings, _ = self.tfm_embedding_lookup(target_ids)
-            tgt_embedding_output = self.tfm_embedding_postprocessor(tgt_word_embeddings)
+            tgt_embedding_output = self.tfm_embedding_postprocessor_for_decoder(tgt_word_embeddings)
             # attention mask [batch_size, seq_length, seq_length]
             tgt_attention_mask = self._create_attention_mask_from_input_mask(target_mask)
             tgt_attention_mask = self.multiply(tgt_attention_mask, self.expand(self.future_mask, 0))
@@ -1223,15 +1225,14 @@ class TransformerModel(nn.Cell):
                                               encoder_output, enc_attention_mask)
             # calculate logits and log_probs
             log_probs = self.projection(decoder_output, embedding_tables)
-            return log_probs
-
-        beam_encoder_output = self.tile_beam(encoder_output)
-
-        enc_attention_mask = self.multiply(
-            enc_attention_mask[::, 0:1:1, ::],
-            self.expand(self.encdec_mask, -1))
-
-        beam_enc_attention_mask = self.tile_beam(enc_attention_mask)
-        beam_enc_attention_mask = self.cast_compute_type(beam_enc_attention_mask)
-        predicted_ids = self.tfm_decoder(beam_encoder_output, beam_enc_attention_mask)
-        return predicted_ids
+            ret = log_probs
+        else:
+            beam_encoder_output = self.tile_beam(encoder_output)
+
+            enc_attention_mask = self.multiply(enc_attention_mask[::, 0:1:1, ::], self.expand(self.encdec_mask, -1))
+
+            beam_enc_attention_mask = self.tile_beam(enc_attention_mask)
+            beam_enc_attention_mask = self.cast_compute_type(beam_enc_attention_mask)
+            predicted_ids = self.tfm_decoder(beam_encoder_output, beam_enc_attention_mask)
+            ret = predicted_ids
+        return ret
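
Two related changes run through TransformerModel. The shared EmbeddingPostprocessor is split into tfm_embedding_postprocessor_for_encoder and tfm_embedding_postprocessor_for_decoder, so encoder and decoder each hold their own position-embedding parameters, and construct() now funnels the training and inference branches into a single ret variable with one return, presumably to keep a single exit path for graph-mode compilation. A condensed, illustrative sketch of the resulting control flow; the encoder pass, mask construction, and the training decoder are elided with placeholders:

    # Condensed sketch of the new construct(); not the full method.
    def construct(self, source_ids, source_mask, target_ids=None, target_mask=None):
        src_word_embeddings, embedding_tables = self.tfm_embedding_lookup(source_ids)
        src_embedding_output = self.tfm_embedding_postprocessor_for_encoder(src_word_embeddings)
        encoder_output = ...  # encoder pass over src_embedding_output, elided

        if self.is_training:
            tgt_word_embeddings, _ = self.tfm_embedding_lookup(target_ids)
            tgt_embedding_output = self.tfm_embedding_postprocessor_for_decoder(tgt_word_embeddings)
            decoder_output = ...  # masked decoder pass over tgt_embedding_output, elided
            ret = self.projection(decoder_output, embedding_tables)
        else:
            beam_encoder_output = self.tile_beam(encoder_output)
            beam_enc_attention_mask = ...  # tiled and cast encoder-decoder mask, elided
            ret = self.tfm_decoder(beam_encoder_output, beam_enc_attention_mask)
        return ret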

model_zoo/Transformer/train.py (+9, -21)

@@ -16,9 +16,9 @@


 import time
 import argparse
+import numpy as np

 import mindspore.common.dtype as mstype
-from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
 from mindspore.nn.optim import Adam
 from mindspore.train.model import Model
@@ -34,7 +34,6 @@ from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNe
     TransformerTrainOneStepWithLossScaleCell
 from src.config import cfg, transformer_net_cfg
 from src.dataset import create_transformer_dataset
-from src.weight_init import weight_variable, one_weight, zero_weight, normal_weight
 from src.lr_schedule import create_dynamic_lr




@@ -108,7 +107,7 @@ def run_transformer_train():
     parser = argparse_init()
     args, _ = parser.parse_known_args()
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
-    context.set_context(save_graphs=True, reserve_class_name_in_scope=False, enable_auto_mixed_precision=False)
+    context.set_context(reserve_class_name_in_scope=False, enable_auto_mixed_precision=False)

     if args.distribute == "true":
         device_num = args.device_num
@@ -129,29 +128,15 @@ def run_transformer_train():


     if args.checkpoint_path:
         parameter_dict = load_checkpoint(args.checkpoint_path)
-    else:
-        parameter_dict = {}
-        params = netwithloss.trainable_params()
-        for param in params:
-            name = param.name
-            value = param.default_input
-            if isinstance(value, Tensor):
-                if name.endswith(".gamma"):
-                    parameter_dict[name] = Parameter(one_weight(value.asnumpy().shape), name=name)
-                elif name.endswith(".beta") or name.endswith(".bias"):
-                    parameter_dict[name] = Parameter(zero_weight(value.asnumpy().shape), name=name)
-                elif "embedding" in name:
-                    parameter_dict[name] = Parameter(normal_weight(value.asnumpy().shape,
-                                                                   transformer_net_cfg.hidden_size), name=name)
-                else:
-                    parameter_dict[name] = Parameter(weight_variable(value.asnumpy().shape), name=name)
-    load_param_into_net(netwithloss, parameter_dict)
+        load_param_into_net(netwithloss, parameter_dict)

     lr = Tensor(create_dynamic_lr(schedule="constant*rsqrt_hidden*linear_warmup*rsqrt_decay",
                                   training_steps=dataset.get_dataset_size()*args.epoch_size,
                                   learning_rate=cfg.lr_schedule.learning_rate,
                                   warmup_steps=cfg.lr_schedule.warmup_steps,
-                                  hidden_size=transformer_net_cfg.hidden_size), mstype.float32)
+                                  hidden_size=transformer_net_cfg.hidden_size,
+                                  start_decay_step=cfg.lr_schedule.start_decay_step,
+                                  min_lr=cfg.lr_schedule.min_lr), mstype.float32)
     optimizer = Adam(netwithloss.trainable_params(), lr)

     callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack()]
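
With the weight_init-based fallback removed, train.py no longer builds a parameter_dict by hand when no --checkpoint_path is given; parameters simply keep the initial values the model itself now sets via normal_weight/weight_variable, and load_param_into_net runs only when a checkpoint is actually supplied. The learning-rate call additionally forwards start_decay_step and min_lr from cfg.lr_schedule. Below is a rough, purely illustrative reading of how the factors named in the schedule string could combine; the real logic lives in src/lr_schedule.py and the exact handling of start_decay_step and min_lr is an assumption.

    # Hedged sketch of "constant*rsqrt_hidden*linear_warmup*rsqrt_decay".
    def sketch_dynamic_lr(step, learning_rate, hidden_size, warmup_steps,
                          start_decay_step, min_lr):
        lr = learning_rate                             # "constant"
        lr *= hidden_size ** -0.5                      # "rsqrt_hidden"
        lr *= min(1.0, step / warmup_steps)            # "linear_warmup"
        decay_step = max(step, start_decay_step)       # decay begins after start_decay_step (assumed)
        lr *= max(decay_step, warmup_steps) ** -0.5    # "rsqrt_decay"
        return max(lr, min_lr)                         # floored at min_lr (assumed)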
@@ -176,4 +161,7 @@ def run_transformer_train():
     model.train(repeat_count, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"))

 if __name__ == '__main__':
+    random_seed = 1
+    np.random.seed(random_seed)
+
     run_transformer_train()
