From ff50625ea3e59404ca8a5fdeb6c3d1130b167990 Mon Sep 17 00:00:00 2001 From: gengdongjie Date: Thu, 14 May 2020 21:12:27 +0800 Subject: [PATCH] add o2 amp level to resent50_imagenet2012 --- example/resnet50_imagenet2012/train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/example/resnet50_imagenet2012/train.py b/example/resnet50_imagenet2012/train.py index 1992bfda95..d87189f8a7 100755 --- a/example/resnet50_imagenet2012/train.py +++ b/example/resnet50_imagenet2012/train.py @@ -86,7 +86,9 @@ if __name__ == '__main__': opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay, config.loss_scale) - model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}) + model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O2", + keep_batchnorm_fp32=False) + time_cb = TimeMonitor(data_size=step_size) loss_cb = LossMonitor()