Browse Source

Fix TinyBERT failing when run on 1p (single device)

tags/v1.1.0
yanghaitao1 5 years ago
parent
commit
a4de9ba0eb
1 changed file with 10 additions and 2 deletions
  1. +10
    -2
      model_zoo/official/nlp/tinybert/src/dataset.py

+ 10
- 2
model_zoo/official/nlp/tinybert/src/dataset.py View File

@@ -40,13 +40,21 @@ def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0,
else: else:
columns_list = ["input_ids", "input_mask", "segment_ids"] columns_list = ["input_ids", "input_mask", "segment_ids"]


shard_equal_rows = True
shuffle = (do_shuffle == "true")
if device_num == 1:
shard_equal_rows = False
shuffle = False

if data_type == DataType.MINDRECORD: if data_type == DataType.MINDRECORD:
ds = de.MindDataset(data_files, columns_list=columns_list, ds = de.MindDataset(data_files, columns_list=columns_list,
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank) shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank)
else: else:
ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list, ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list,
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
shard_equal_rows=True)
shuffle=shuffle, num_shards=device_num, shard_id=rank,
shard_equal_rows=shard_equal_rows)
if device_num == 1 and shuffle is True:
ds = ds.shuffle(10000)


type_cast_op = C.TypeCast(mstype.int32) type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(operations=type_cast_op, input_columns="segment_ids") ds = ds.map(operations=type_cast_op, input_columns="segment_ids")


Loading…
Cancel
Save