| @@ -54,13 +54,12 @@ def load_test_data(batch_size=1): | |||
| return ret | |||
| def get_config(version='base', batch_size=1): | |||
| def get_config(version='base'): | |||
| """ | |||
| get_config definition | |||
| """ | |||
| if version == 'base': | |||
| return BertConfig( | |||
| batch_size=batch_size, | |||
| seq_length=128, | |||
| vocab_size=21128, | |||
| hidden_size=768, | |||
| @@ -74,13 +73,10 @@ def get_config(version='base', batch_size=1): | |||
| type_vocab_size=2, | |||
| initializer_range=0.02, | |||
| use_relative_positions=True, | |||
| input_mask_from_dataset=True, | |||
| token_type_ids_from_dataset=True, | |||
| dtype=mstype.float32, | |||
| compute_type=mstype.float32) | |||
| if version == 'large': | |||
| return BertConfig( | |||
| batch_size=batch_size, | |||
| seq_length=128, | |||
| vocab_size=21128, | |||
| hidden_size=1024, | |||
| @@ -94,11 +90,9 @@ def get_config(version='base', batch_size=1): | |||
| type_vocab_size=2, | |||
| initializer_range=0.02, | |||
| use_relative_positions=True, | |||
| input_mask_from_dataset=True, | |||
| token_type_ids_from_dataset=True, | |||
| dtype=mstype.float32, | |||
| compute_type=mstype.float32) | |||
| return BertConfig(batch_size=batch_size) | |||
| return BertConfig() | |||
| class BertLearningRate(lr_schedules.LearningRateSchedule): | |||
| @@ -143,7 +137,7 @@ def test_bert_train(): | |||
| batch_size = int(os.getenv('BATCH_SIZE', '1')) | |||
| inputs = load_test_data(batch_size) | |||
| config = get_config(version=version, batch_size=batch_size) | |||
| config = get_config(version=version) | |||
| netwithloss = BertNetworkWithLoss(config, True) | |||
| lr = BertLearningRate(10) | |||
| optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) | |||
| @@ -168,7 +162,7 @@ def test_bert_withlossscale_train(): | |||
| scaling_sens = Tensor(np.ones([1]).astype(np.float32)) | |||
| inputs = load_test_data(batch_size) + (scaling_sens,) | |||
| config = get_config(version=version, batch_size=batch_size) | |||
| config = get_config(version=version) | |||
| netwithloss = BertNetworkWithLoss(config, True) | |||
| lr = BertLearningRate(10) | |||
| optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) | |||
| @@ -195,7 +189,7 @@ def bert_withlossscale_manager_train(): | |||
| batch_size = int(os.getenv('BATCH_SIZE', '1')) | |||
| inputs = load_test_data(batch_size) | |||
| config = get_config(version=version, batch_size=batch_size) | |||
| config = get_config(version=version) | |||
| netwithloss = BertNetworkWithLoss(config, True) | |||
| lr = BertLearningRate(10) | |||
| optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) | |||
| @@ -223,7 +217,7 @@ def bert_withlossscale_manager_train_feed(): | |||
| scaling_sens = Tensor(np.ones([1]).astype(np.float32)) | |||
| inputs = load_test_data(batch_size) + (scaling_sens,) | |||
| config = get_config(version=version, batch_size=batch_size) | |||
| config = get_config(version=version) | |||
| netwithloss = BertNetworkWithLoss(config, True) | |||
| lr = BertLearningRate(10) | |||
| optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) | |||