|
|
|
@@ -39,6 +39,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e |
|
|
|
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank, |
|
|
|
shard_equal_rows=True) |
|
|
|
ori_dataset_size = ds.get_dataset_size() |
|
|
|
print('origin dataset size: ', ori_dataset_size) |
|
|
|
new_size = ori_dataset_size |
|
|
|
if enable_data_sink == "true": |
|
|
|
new_size = data_sink_steps * bert_net_cfg.batch_size |
|
|
|
@@ -53,7 +54,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e |
|
|
|
ds = ds.map(input_columns="input_ids", operations=type_cast_op) |
|
|
|
# apply batch operations |
|
|
|
ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True) |
|
|
|
ds = ds.repeat(new_repeat_count) |
|
|
|
ds = ds.repeat(max(new_repeat_count, repeat_count)) |
|
|
|
logger.info("data size: {}".format(ds.get_dataset_size())) |
|
|
|
logger.info("repeatcount: {}".format(ds.get_repeat_count())) |
|
|
|
return ds, new_repeat_count |