Browse Source

fix sink_size bug for transformer

tags/v1.0.0
yuchaojie 5 years ago
parent
commit
9fb6f0c34b
1 changed files with 8 additions and 2 deletions
  1. +8
    -2
      model_zoo/official/nlp/transformer/train.py

+ 8
- 2
model_zoo/official/nlp/transformer/train.py View File

@@ -170,8 +170,14 @@ def run_transformer_train():

netwithgrads.set_train(True)
model = Model(netwithgrads)
model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"),
sink_size=args.save_checkpoint_steps)

enable_sink = (args.enable_data_sink == "true")
if enable_sink:
sink_size = args.save_checkpoint_steps
model.train(args.epoch_size*dataset.get_dataset_size()//sink_size, dataset, callbacks=callbacks,
dataset_sink_mode=enable_sink, sink_size=sink_size)
else:
model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=enable_sink)

if __name__ == '__main__':
run_transformer_train()

Loading…
Cancel
Save