|
|
|
@@ -207,7 +207,7 @@ if __name__ == '__main__': |
|
|
|
metrics = {"acc"} |
|
|
|
if args_opt.run_distribute: |
|
|
|
metrics = {'acc': DistAccuracy(batch_size=config.batch_size, device_num=args_opt.device_num)} |
|
|
|
if (args_opt.net != "resnet101" and args_opt.net != "resnet50") or \ |
|
|
|
if (args_opt.net not in ("resnet18", "resnet50", "resnet101")) or \ |
|
|
|
args_opt.parameter_server or target == "CPU": |
|
|
|
## fp32 training |
|
|
|
model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics, eval_network=dist_eval_network) |
|
|
|
|