|
|
|
@@ -40,13 +40,21 @@ def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0, |
|
|
|
else: |
|
|
|
columns_list = ["input_ids", "input_mask", "segment_ids"] |
|
|
|
|
|
|
|
shard_equal_rows = True |
|
|
|
shuffle = (do_shuffle == "true") |
|
|
|
if device_num == 1: |
|
|
|
shard_equal_rows = False |
|
|
|
shuffle = False |
|
|
|
|
|
|
|
if data_type == DataType.MINDRECORD: |
|
|
|
ds = de.MindDataset(data_files, columns_list=columns_list, |
|
|
|
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank) |
|
|
|
else: |
|
|
|
ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list, |
|
|
|
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank, |
|
|
|
shard_equal_rows=True) |
|
|
|
shuffle=shuffle, num_shards=device_num, shard_id=rank, |
|
|
|
shard_equal_rows=shard_equal_rows) |
|
|
|
if device_num == 1 and shuffle is True: |
|
|
|
ds = ds.shuffle(10000) |
|
|
|
|
|
|
|
type_cast_op = C.TypeCast(mstype.int32) |
|
|
|
ds = ds.map(operations=type_cast_op, input_columns="segment_ids") |
|
|
|
|