Browse Source

add o2 amp level to resent50_imagenet2012

tags/v0.3.0-alpha
gengdongjie 5 years ago
parent
commit
ff50625ea3
1 changed files with 3 additions and 1 deletions
  1. +3
    -1
      example/resnet50_imagenet2012/train.py

+ 3
- 1
example/resnet50_imagenet2012/train.py View File

@@ -86,7 +86,9 @@ if __name__ == '__main__':
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
config.weight_decay, config.loss_scale)

model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O2",
keep_batchnorm_fp32=False)


time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor()


Loading…
Cancel
Save