|
|
|
@@ -76,26 +76,6 @@ def get_config(version='base', batch_size=1): |
|
|
|
token_type_ids_from_dataset=True, |
|
|
|
dtype=mstype.float32, |
|
|
|
compute_type=mstype.float16) |
|
|
|
elif version == 'large_mixed': |
|
|
|
bert_config = BertConfig( |
|
|
|
batch_size=batch_size, |
|
|
|
seq_length=128, |
|
|
|
vocab_size=21136, |
|
|
|
hidden_size=1024, |
|
|
|
num_hidden_layers=24, |
|
|
|
num_attention_heads=16, |
|
|
|
intermediate_size=4096, |
|
|
|
hidden_act="gelu", |
|
|
|
hidden_dropout_prob=0.0, |
|
|
|
attention_probs_dropout_prob=0.0, |
|
|
|
max_position_embeddings=512, |
|
|
|
type_vocab_size=2, |
|
|
|
initializer_range=0.02, |
|
|
|
use_relative_positions=True, |
|
|
|
input_mask_from_dataset=True, |
|
|
|
token_type_ids_from_dataset=True, |
|
|
|
dtype=mstype.float32, |
|
|
|
compute_type=mstype.float32) |
|
|
|
else: |
|
|
|
bert_config = BertConfig(batch_size=batch_size) |
|
|
|
return bert_config |
|
|
|
@@ -136,8 +116,8 @@ class ModelCallback(Callback): |
|
|
|
def step_end(self, run_context):
    """Record per-step training outputs from the callback context.

    Appends, for each step:
      - the scalar loss value (first element of net_outputs, as a numpy scalar),
      - the overflow flag (second element, converted via .asnumpy()),
      - the current loss scale (third element, converted via .asnumpy()).

    NOTE(review): SOURCE is a diff with +/- markers stripped; the original
    pre-diff appends (raw Tensor objects, without .asnumpy()) were interleaved
    with the post-diff .asnumpy() versions, which would double-append into
    overflow_list/lossscale_list. This body keeps only the post-diff lines,
    matching the hunk header's equal-length replacement (-136,8 +116,8).

    Args:
        run_context: MindSpore RunContext; original_args() yields cb_params
            whose net_outputs is (loss, overflow_flag, loss_scale).
    """
    cb_params = run_context.original_args()
    # net_outputs[0] is the loss tensor; asnumpy()[0] extracts the scalar.
    self.loss_list.append(cb_params.net_outputs[0].asnumpy()[0])
    self.overflow_list.append(cb_params.net_outputs[1].asnumpy())
    self.lossscale_list.append(cb_params.net_outputs[2].asnumpy())
    print("epoch: {}, outputs are: {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs)))
|
|
|
|
|
|
|
@pytest.mark.level0 |
|
|
|
@@ -157,7 +137,7 @@ def test_bert_tdt(): |
|
|
|
netwithloss = BertNetworkWithLoss(config, True) |
|
|
|
optimizer = Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9) |
|
|
|
scale_window = 3 |
|
|
|
scale_manager = DynamicLossScaleManager(2**32, 2, scale_window) |
|
|
|
scale_manager = DynamicLossScaleManager(2**16, 2, scale_window) |
|
|
|
netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=scale_manager.get_update_cell()) |
|
|
|
netwithgrads.set_train(True) |
|
|
|
model = Model(netwithgrads) |
|
|
|
@@ -182,22 +162,21 @@ def test_bert_tdt(): |
|
|
|
param.default_input = weight_variable(value.asnumpy().shape) |
|
|
|
model.train(ds.get_repeat_count(), ds, callbacks=callback, dataset_sink_mode=False) |
|
|
|
|
|
|
|
# assertion occurs while the loss_scale value is wrong |
|
|
|
count = 0 |
|
|
|
for i in range(len(callback.overflow_list)): |
|
|
|
if callback.overflow_list[i] == Tensor(True, mstype.bool_) and i > 0: |
|
|
|
count = 0 |
|
|
|
assert callback.lossscale_list[i] == callback.lossscale_list[i - 1] * Tensor(0.5, mstype.float32) |
|
|
|
if callback.overflow_list[i] == Tensor(False, mstype.bool_): |
|
|
|
count = count + 1 |
|
|
|
if count == scale_window: |
|
|
|
count = 0 |
|
|
|
assert callback.lossscale_list[i] == callback.lossscale_list[i - 1] * Tensor(2.0, mstype.float32) |
|
|
|
# assertion occurs while the loss value is wrong |
|
|
|
# assertion occurs while the loss value, overflow state or loss_scale value is wrong |
|
|
|
loss_value = np.array(callback.loss_list) |
|
|
|
expect_value = [12.1918125, 11.966035, 11.972114, 11.982671, 11.976399, 12.616986, 12.180658, 12.850562, 12.415608, 12.640145] |
|
|
|
expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982188, 11.974092, 12.610916, 12.17565, 12.840416, 12.40291, 12.621661] |
|
|
|
print("loss value: {}".format(loss_value)) |
|
|
|
assert np.allclose(loss_value, expect_value, 0.00001, 0.00001) |
|
|
|
assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001) |
|
|
|
|
|
|
|
overflow = np.array(callback.overflow_list) |
|
|
|
expect_overflow = [True, True, False, False, False, True, False, False, False, True] |
|
|
|
print("overflow: {}".format(overflow)) |
|
|
|
assert (overflow == expect_overflow).all() |
|
|
|
|
|
|
|
loss_scale = np.array(callback.lossscale_list) |
|
|
|
expect_loss_scale = [32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0] |
|
|
|
print("loss scale: {}".format(loss_scale)) |
|
|
|
assert np.allclose(loss_scale, expect_loss_scale, 0.00001, 0.00001) |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
test_bert_tdt() |