|
|
|
@@ -36,8 +36,8 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e |
|
|
|
ds = de.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None, |
|
|
|
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", |
|
|
|
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"], |
|
|
|
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank, |
|
|
|
shard_equal_rows=True) |
|
|
|
shuffle=de.Shuffle.FILES if do_shuffle == "true" else False, |
|
|
|
num_shards=device_num, shard_id=rank, shard_equal_rows=True) |
|
|
|
ori_dataset_size = ds.get_dataset_size() |
|
|
|
print('origin dataset size: ', ori_dataset_size) |
|
|
|
new_size = ori_dataset_size |
|
|
|
|