|
|
|
@@ -146,10 +146,14 @@ def main(): |
|
|
|
loss_scale_manager = FixedLossScaleManager( |
|
|
|
cfg.loss_scale, drop_overflow_update=False) |
|
|
|
|
|
|
|
config_ck = CheckpointConfig( |
|
|
|
save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max) |
|
|
|
ckpoint_cb = ModelCheckpoint( |
|
|
|
prefix=cfg.model, directory=output_dir, config=config_ck) |
|
|
|
callbacks = [time_cb, loss_cb] |
|
|
|
|
|
|
|
if cfg.save_checkpoint: |
|
|
|
config_ck = CheckpointConfig( |
|
|
|
save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max) |
|
|
|
ckpoint_cb = ModelCheckpoint( |
|
|
|
prefix=cfg.model, directory=output_dir, config=config_ck) |
|
|
|
callbacks += [ckpoint_cb] |
|
|
|
|
|
|
|
lr = Tensor(get_lr(base_lr=cfg.lr, total_epochs=cfg.epochs, steps_per_epoch=batches_per_epoch, |
|
|
|
decay_steps=cfg.decay_epochs, decay_rate=cfg.decay_rate, |
|
|
|
@@ -176,7 +180,7 @@ def main(): |
|
|
|
amp_level=cfg.amp_level |
|
|
|
) |
|
|
|
|
|
|
|
callbacks = [loss_cb, ckpoint_cb, time_cb] if is_master else [] |
|
|
|
callbacks = callbacks if is_master else [] |
|
|
|
|
|
|
|
if args.resume: |
|
|
|
real_epoch = cfg.epochs - cfg.resume_start_epoch |
|
|
|
|