You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

benchmark.py 4.5 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. """
  3. A script to benchmark builtin models.
  4. Note: this script has an extra dependency of psutil.
  5. """
  6. import itertools
  7. import logging
  8. import psutil
  9. import torch
  10. import tqdm
  11. from fvcore.common.timer import Timer
  12. from torch.nn.parallel import DistributedDataParallel
  13. from detectron2.checkpoint import DetectionCheckpointer
  14. from detectron2.config import get_cfg
  15. from detectron2.data import (
  16. DatasetFromList,
  17. build_detection_test_loader,
  18. build_detection_train_loader,
  19. )
  20. from detectron2.engine import SimpleTrainer, default_argument_parser, hooks, launch
  21. from detectron2.modeling import build_model
  22. from detectron2.solver import build_optimizer
  23. from detectron2.utils import comm
  24. from detectron2.utils.events import CommonMetricPrinter
  25. from detectron2.utils.logger import setup_logger
  26. logger = logging.getLogger("detectron2")
  27. def setup(args):
  28. cfg = get_cfg()
  29. cfg.merge_from_file(args.config_file)
  30. cfg.SOLVER.BASE_LR = 0.001 # Avoid NaNs. Not useful in this script anyway.
  31. cfg.merge_from_list(args.opts)
  32. cfg.freeze()
  33. setup_logger(distributed_rank=comm.get_rank())
  34. return cfg
  35. def benchmark_data(args):
  36. cfg = setup(args)
  37. dataloader = build_detection_train_loader(cfg)
  38. timer = Timer()
  39. itr = iter(dataloader)
  40. for i in range(10): # warmup
  41. next(itr)
  42. if i == 0:
  43. startup_time = timer.seconds()
  44. timer = Timer()
  45. max_iter = 1000
  46. for _ in tqdm.trange(max_iter):
  47. next(itr)
  48. logger.info(
  49. "{} iters ({} images) in {} seconds.".format(
  50. max_iter, max_iter * cfg.SOLVER.IMS_PER_BATCH, timer.seconds()
  51. )
  52. )
  53. logger.info("Startup time: {} seconds".format(startup_time))
  54. vram = psutil.virtual_memory()
  55. logger.info(
  56. "RAM Usage: {:.2f}/{:.2f} GB".format(
  57. (vram.total - vram.available) / 1024 ** 3, vram.total / 1024 ** 3
  58. )
  59. )
  60. def benchmark_train(args):
  61. cfg = setup(args)
  62. model = build_model(cfg)
  63. logger.info("Model:\n{}".format(model))
  64. if comm.get_world_size() > 1:
  65. model = DistributedDataParallel(
  66. model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
  67. )
  68. optimizer = build_optimizer(cfg, model)
  69. checkpointer = DetectionCheckpointer(model, optimizer=optimizer)
  70. checkpointer.load(cfg.MODEL.WEIGHTS)
  71. cfg.defrost()
  72. cfg.DATALOADER.NUM_WORKERS = 0
  73. data_loader = build_detection_train_loader(cfg)
  74. dummy_data = list(itertools.islice(data_loader, 100))
  75. def f():
  76. while True:
  77. yield from DatasetFromList(dummy_data, copy=False)
  78. max_iter = 400
  79. trainer = SimpleTrainer(model, f(), optimizer)
  80. trainer.register_hooks(
  81. [hooks.IterationTimer(), hooks.PeriodicWriter([CommonMetricPrinter(max_iter)])]
  82. )
  83. trainer.train(1, max_iter)
  84. @torch.no_grad()
  85. def benchmark_eval(args):
  86. cfg = setup(args)
  87. model = build_model(cfg)
  88. model.eval()
  89. logger.info("Model:\n{}".format(model))
  90. DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
  91. cfg.defrost()
  92. cfg.DATALOADER.NUM_WORKERS = 0
  93. data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
  94. dummy_data = list(itertools.islice(data_loader, 100))
  95. def f():
  96. while True:
  97. yield from DatasetFromList(dummy_data, copy=False)
  98. for _ in range(5): # warmup
  99. model(dummy_data[0])
  100. max_iter = 400
  101. timer = Timer()
  102. with tqdm.tqdm(total=max_iter) as pbar:
  103. for idx, d in enumerate(f()):
  104. if idx == max_iter:
  105. break
  106. model(d)
  107. pbar.update()
  108. logger.info("{} iters in {} seconds.".format(max_iter, timer.seconds()))
  109. if __name__ == "__main__":
  110. parser = default_argument_parser()
  111. parser.add_argument("--task", choices=["train", "eval", "data"], required=True)
  112. args = parser.parse_args()
  113. assert not args.eval_only
  114. if args.task == "data":
  115. f = benchmark_data
  116. elif args.task == "train":
  117. """
  118. Note: training speed may not be representative.
  119. The training cost of a R-CNN model varies with the content of the data
  120. and the quality of the model.
  121. """
  122. f = benchmark_train
  123. elif args.task == "eval":
  124. f = benchmark_eval
  125. # only benchmark single-GPU inference.
  126. assert args.num_gpus == 1 and args.num_machines == 1
  127. launch(f, args.num_gpus, args.num_machines, args.machine_rank, args.dist_url, args=(args,))

No Description