|
|
|
@@ -170,8 +170,14 @@ def run_transformer_train(): |
|
|
|
|
|
|
|
netwithgrads.set_train(True) |
|
|
|
model = Model(netwithgrads) |
|
|
|
model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"), |
|
|
|
sink_size=args.save_checkpoint_steps) |
|
|
|
|
|
|
|
enable_sink = (args.enable_data_sink == "true") |
|
|
|
if enable_sink: |
|
|
|
sink_size = args.save_checkpoint_steps |
|
|
|
model.train(args.epoch_size*dataset.get_dataset_size()//sink_size, dataset, callbacks=callbacks, |
|
|
|
dataset_sink_mode=enable_sink, sink_size=sink_size) |
|
|
|
else: |
|
|
|
model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=enable_sink) |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
run_transformer_train() |