|
|
|
@@ -17,10 +17,10 @@ |
|
|
|
import mindspore.common.dtype as mstype |
|
|
|
import mindspore.dataset as de |
|
|
|
import mindspore.dataset.transforms.c_transforms as deC |
|
|
|
from .config import transformer_net_cfg |
|
|
|
from .config import transformer_net_cfg, transformer_net_cfg_gpu |
|
|
|
de.config.set_seed(1) |
|
|
|
def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true", dataset_path=None, |
|
|
|
bucket_boundaries=None): |
|
|
|
bucket_boundaries=None, device_target="Ascend"): |
|
|
|
"""create dataset""" |
|
|
|
def batch_per_bucket(bucket_len, dataset_path): |
|
|
|
dataset_path = dataset_path + "_" + str(bucket_len) + "_00" |
|
|
|
@@ -38,7 +38,11 @@ def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle |
|
|
|
ds = ds.map(operations=type_cast_op, input_columns="target_eos_mask") |
|
|
|
|
|
|
|
# apply batch operations |
|
|
|
ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True) |
|
|
|
if device_target == "Ascend": |
|
|
|
ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True) |
|
|
|
else: |
|
|
|
ds = ds.batch(transformer_net_cfg_gpu.batch_size, drop_remainder=True) |
|
|
|
|
|
|
|
ds = ds.repeat(epoch_count) |
|
|
|
return ds |
|
|
|
|
|
|
|
|