Browse Source

update transformer scripts

tags/v1.1.0
panfengfeng 5 years ago
parent
commit
752035795f
2 changed files with 9 additions and 4 deletions
  1. +7
    -3
      model_zoo/official/nlp/transformer/src/dataset.py
  2. +2
    -1
      model_zoo/official/nlp/transformer/train.py

+ 7
- 3
model_zoo/official/nlp/transformer/src/dataset.py View File

@@ -17,10 +17,10 @@
import mindspore.common.dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as deC
from .config import transformer_net_cfg
from .config import transformer_net_cfg, transformer_net_cfg_gpu
de.config.set_seed(1)
def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true", dataset_path=None,
bucket_boundaries=None):
bucket_boundaries=None, device_target="Ascend"):
"""create dataset"""
def batch_per_bucket(bucket_len, dataset_path):
dataset_path = dataset_path + "_" + str(bucket_len) + "_00"
@@ -38,7 +38,11 @@ def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle
ds = ds.map(operations=type_cast_op, input_columns="target_eos_mask")

# apply batch operations
ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True)
if device_target == "Ascend":
ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True)
else:
ds = ds.batch(transformer_net_cfg_gpu.batch_size, drop_remainder=True)

ds = ds.repeat(epoch_count)
return ds



+ 2
- 1
model_zoo/official/nlp/transformer/train.py View File

@@ -146,7 +146,8 @@ def run_transformer_train():
dataset = create_transformer_dataset(epoch_count=1, rank_size=device_num,
rank_id=rank_id, do_shuffle=args.do_shuffle,
dataset_path=args.data_path,
bucket_boundaries=args.bucket_boundaries)
bucket_boundaries=args.bucket_boundaries,
device_target=args.device_target)
if args.device_target == "Ascend":
netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True)
else:


Loading…
Cancel
Save