|
|
|
@@ -77,7 +77,8 @@ if __name__ == '__main__': |
|
|
|
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, |
|
|
|
config.weight_decay, config.loss_scale) |
|
|
|
|
|
|
|
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}) |
|
|
|
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O2", |
|
|
|
keep_batchnorm_fp32=False) |
|
|
|
|
|
|
|
time_cb = TimeMonitor(data_size=step_size) |
|
|
|
loss_cb = LossMonitor() |
|
|
|
|