| @@ -40,13 +40,21 @@ def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0, | |||||
| else: | else: | ||||
| columns_list = ["input_ids", "input_mask", "segment_ids"] | columns_list = ["input_ids", "input_mask", "segment_ids"] | ||||
| shard_equal_rows = True | |||||
| shuffle = (do_shuffle == "true") | |||||
| if device_num == 1: | |||||
| shard_equal_rows = False | |||||
| shuffle = False | |||||
| if data_type == DataType.MINDRECORD: | if data_type == DataType.MINDRECORD: | ||||
| ds = de.MindDataset(data_files, columns_list=columns_list, | ds = de.MindDataset(data_files, columns_list=columns_list, | ||||
| shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank) | shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank) | ||||
| else: | else: | ||||
| ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list, | ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list, | ||||
| shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank, | |||||
| shard_equal_rows=True) | |||||
| shuffle=shuffle, num_shards=device_num, shard_id=rank, | |||||
| shard_equal_rows=shard_equal_rows) | |||||
| if device_num == 1 and shuffle is True: | |||||
| ds = ds.shuffle(10000) | |||||
| type_cast_op = C.TypeCast(mstype.int32) | type_cast_op = C.TypeCast(mstype.int32) | ||||
| ds = ds.map(operations=type_cast_op, input_columns="segment_ids") | ds = ds.map(operations=type_cast_op, input_columns="segment_ids") | ||||