Browse Source

fix performance jitter for googlenet train

pull/15160/head
CaoJian 4 years ago
parent
commit
de4828a71f
1 changed files with 5 additions and 1 deletions
  1. +5
    -1
      model_zoo/official/cv/googlenet/train.py

+ 5
- 1
model_zoo/official/cv/googlenet/train.py View File

@@ -205,5 +205,9 @@ if __name__ == '__main__':
ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_" + args_opt.dataset_name, directory=ckpt_save_dir,
config=config_ck)
loss_cb = LossMonitor()
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])

cbs = [time_cb, ckpoint_cb, loss_cb]
if device_num > 1 and rank != 0:
cbs = [time_cb, loss_cb]
model.train(cfg.epoch_size, dataset, callbacks=cbs)
print("train success")

Loading…
Cancel
Save