| @@ -54,13 +54,12 @@ def load_test_data(batch_size=1): | |||||
| return ret | return ret | ||||
| def get_config(version='base', batch_size=1): | |||||
| def get_config(version='base'): | |||||
| """ | """ | ||||
| get_config definition | get_config definition | ||||
| """ | """ | ||||
| if version == 'base': | if version == 'base': | ||||
| return BertConfig( | return BertConfig( | ||||
| batch_size=batch_size, | |||||
| seq_length=128, | seq_length=128, | ||||
| vocab_size=21128, | vocab_size=21128, | ||||
| hidden_size=768, | hidden_size=768, | ||||
| @@ -74,13 +73,10 @@ def get_config(version='base', batch_size=1): | |||||
| type_vocab_size=2, | type_vocab_size=2, | ||||
| initializer_range=0.02, | initializer_range=0.02, | ||||
| use_relative_positions=True, | use_relative_positions=True, | ||||
| input_mask_from_dataset=True, | |||||
| token_type_ids_from_dataset=True, | |||||
| dtype=mstype.float32, | dtype=mstype.float32, | ||||
| compute_type=mstype.float32) | compute_type=mstype.float32) | ||||
| if version == 'large': | if version == 'large': | ||||
| return BertConfig( | return BertConfig( | ||||
| batch_size=batch_size, | |||||
| seq_length=128, | seq_length=128, | ||||
| vocab_size=21128, | vocab_size=21128, | ||||
| hidden_size=1024, | hidden_size=1024, | ||||
| @@ -94,11 +90,9 @@ def get_config(version='base', batch_size=1): | |||||
| type_vocab_size=2, | type_vocab_size=2, | ||||
| initializer_range=0.02, | initializer_range=0.02, | ||||
| use_relative_positions=True, | use_relative_positions=True, | ||||
| input_mask_from_dataset=True, | |||||
| token_type_ids_from_dataset=True, | |||||
| dtype=mstype.float32, | dtype=mstype.float32, | ||||
| compute_type=mstype.float32) | compute_type=mstype.float32) | ||||
| return BertConfig(batch_size=batch_size) | |||||
| return BertConfig() | |||||
| class BertLearningRate(lr_schedules.LearningRateSchedule): | class BertLearningRate(lr_schedules.LearningRateSchedule): | ||||
| @@ -143,7 +137,7 @@ def test_bert_train(): | |||||
| batch_size = int(os.getenv('BATCH_SIZE', '1')) | batch_size = int(os.getenv('BATCH_SIZE', '1')) | ||||
| inputs = load_test_data(batch_size) | inputs = load_test_data(batch_size) | ||||
| config = get_config(version=version, batch_size=batch_size) | |||||
| config = get_config(version=version) | |||||
| netwithloss = BertNetworkWithLoss(config, True) | netwithloss = BertNetworkWithLoss(config, True) | ||||
| lr = BertLearningRate(10) | lr = BertLearningRate(10) | ||||
| optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) | optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) | ||||
| @@ -168,7 +162,7 @@ def test_bert_withlossscale_train(): | |||||
| scaling_sens = Tensor(np.ones([1]).astype(np.float32)) | scaling_sens = Tensor(np.ones([1]).astype(np.float32)) | ||||
| inputs = load_test_data(batch_size) + (scaling_sens,) | inputs = load_test_data(batch_size) + (scaling_sens,) | ||||
| config = get_config(version=version, batch_size=batch_size) | |||||
| config = get_config(version=version) | |||||
| netwithloss = BertNetworkWithLoss(config, True) | netwithloss = BertNetworkWithLoss(config, True) | ||||
| lr = BertLearningRate(10) | lr = BertLearningRate(10) | ||||
| optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) | optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) | ||||
| @@ -195,7 +189,7 @@ def bert_withlossscale_manager_train(): | |||||
| batch_size = int(os.getenv('BATCH_SIZE', '1')) | batch_size = int(os.getenv('BATCH_SIZE', '1')) | ||||
| inputs = load_test_data(batch_size) | inputs = load_test_data(batch_size) | ||||
| config = get_config(version=version, batch_size=batch_size) | |||||
| config = get_config(version=version) | |||||
| netwithloss = BertNetworkWithLoss(config, True) | netwithloss = BertNetworkWithLoss(config, True) | ||||
| lr = BertLearningRate(10) | lr = BertLearningRate(10) | ||||
| optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) | optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) | ||||
| @@ -223,7 +217,7 @@ def bert_withlossscale_manager_train_feed(): | |||||
| scaling_sens = Tensor(np.ones([1]).astype(np.float32)) | scaling_sens = Tensor(np.ones([1]).astype(np.float32)) | ||||
| inputs = load_test_data(batch_size) + (scaling_sens,) | inputs = load_test_data(batch_size) + (scaling_sens,) | ||||
| config = get_config(version=version, batch_size=batch_size) | |||||
| config = get_config(version=version) | |||||
| netwithloss = BertNetworkWithLoss(config, True) | netwithloss = BertNetworkWithLoss(config, True) | ||||
| lr = BertLearningRate(10) | lr = BertLearningRate(10) | ||||
| optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) | optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) | ||||