|
|
|
@@ -54,13 +54,12 @@ def load_test_data(batch_size=1): |
|
|
|
return ret |
|
|
|
|
|
|
|
|
|
|
|
def get_config(version='base', batch_size=1): |
|
|
|
def get_config(version='base'): |
|
|
|
""" |
|
|
|
get_config definition |
|
|
|
""" |
|
|
|
if version == 'base': |
|
|
|
return BertConfig( |
|
|
|
batch_size=batch_size, |
|
|
|
seq_length=128, |
|
|
|
vocab_size=21128, |
|
|
|
hidden_size=768, |
|
|
|
@@ -74,13 +73,10 @@ def get_config(version='base', batch_size=1): |
|
|
|
type_vocab_size=2, |
|
|
|
initializer_range=0.02, |
|
|
|
use_relative_positions=True, |
|
|
|
input_mask_from_dataset=True, |
|
|
|
token_type_ids_from_dataset=True, |
|
|
|
dtype=mstype.float32, |
|
|
|
compute_type=mstype.float32) |
|
|
|
if version == 'large': |
|
|
|
return BertConfig( |
|
|
|
batch_size=batch_size, |
|
|
|
seq_length=128, |
|
|
|
vocab_size=21128, |
|
|
|
hidden_size=1024, |
|
|
|
@@ -94,11 +90,9 @@ def get_config(version='base', batch_size=1): |
|
|
|
type_vocab_size=2, |
|
|
|
initializer_range=0.02, |
|
|
|
use_relative_positions=True, |
|
|
|
input_mask_from_dataset=True, |
|
|
|
token_type_ids_from_dataset=True, |
|
|
|
dtype=mstype.float32, |
|
|
|
compute_type=mstype.float32) |
|
|
|
return BertConfig(batch_size=batch_size) |
|
|
|
return BertConfig() |
|
|
|
|
|
|
|
|
|
|
|
class BertLearningRate(lr_schedules.LearningRateSchedule): |
|
|
|
@@ -143,7 +137,7 @@ def test_bert_train(): |
|
|
|
batch_size = int(os.getenv('BATCH_SIZE', '1')) |
|
|
|
inputs = load_test_data(batch_size) |
|
|
|
|
|
|
|
config = get_config(version=version, batch_size=batch_size) |
|
|
|
config = get_config(version=version) |
|
|
|
netwithloss = BertNetworkWithLoss(config, True) |
|
|
|
lr = BertLearningRate(10) |
|
|
|
optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) |
|
|
|
@@ -168,7 +162,7 @@ def test_bert_withlossscale_train(): |
|
|
|
scaling_sens = Tensor(np.ones([1]).astype(np.float32)) |
|
|
|
inputs = load_test_data(batch_size) + (scaling_sens,) |
|
|
|
|
|
|
|
config = get_config(version=version, batch_size=batch_size) |
|
|
|
config = get_config(version=version) |
|
|
|
netwithloss = BertNetworkWithLoss(config, True) |
|
|
|
lr = BertLearningRate(10) |
|
|
|
optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) |
|
|
|
@@ -195,7 +189,7 @@ def bert_withlossscale_manager_train(): |
|
|
|
batch_size = int(os.getenv('BATCH_SIZE', '1')) |
|
|
|
inputs = load_test_data(batch_size) |
|
|
|
|
|
|
|
config = get_config(version=version, batch_size=batch_size) |
|
|
|
config = get_config(version=version) |
|
|
|
netwithloss = BertNetworkWithLoss(config, True) |
|
|
|
lr = BertLearningRate(10) |
|
|
|
optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) |
|
|
|
@@ -223,7 +217,7 @@ def bert_withlossscale_manager_train_feed(): |
|
|
|
scaling_sens = Tensor(np.ones([1]).astype(np.float32)) |
|
|
|
inputs = load_test_data(batch_size) + (scaling_sens,) |
|
|
|
|
|
|
|
config = get_config(version=version, batch_size=batch_size) |
|
|
|
config = get_config(version=version) |
|
|
|
netwithloss = BertNetworkWithLoss(config, True) |
|
|
|
lr = BertLearningRate(10) |
|
|
|
optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) |
|
|
|
|