|
|
|
@@ -163,7 +163,7 @@ if __name__ == '__main__': |
|
|
|
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", is_grad=False, |
|
|
|
num_classes=config.class_num) |
|
|
|
|
|
|
|
if args_opt.net == "resnet101": |
|
|
|
if args_opt.net == "resnet101" or args_opt.net == "resnet50": |
|
|
|
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay, |
|
|
|
config.loss_scale) |
|
|
|
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) |
|
|
|
|