|
|
|
@@ -205,5 +205,9 @@ if __name__ == '__main__': |
|
|
|
ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_" + args_opt.dataset_name, directory=ckpt_save_dir, |
|
|
|
config=config_ck) |
|
|
|
loss_cb = LossMonitor() |
|
|
|
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) |
|
|
|
|
|
|
|
cbs = [time_cb, ckpoint_cb, loss_cb] |
|
|
|
if device_num > 1 and rank != 0: |
|
|
|
cbs = [time_cb, loss_cb] |
|
|
|
model.train(cfg.epoch_size, dataset, callbacks=cbs) |
|
|
|
print("train success") |