|
|
|
@@ -111,7 +111,7 @@ if __name__ == '__main__': |
|
|
|
if args_opt.net == "resnet50" or args_opt.net == "se-resnet50": |
|
|
|
context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160]) |
|
|
|
elif args_opt.net == "resnet101": |
|
|
|
context.set_auto_parallel_context(all_reduce_fusion_config=[180, 313]) |
|
|
|
context.set_auto_parallel_context(all_reduce_fusion_config=[80, 210, 313]) |
|
|
|
init() |
|
|
|
# GPU target |
|
|
|
else: |
|
|
|
|