You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

benchmark.py 6.4 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import copy
  4. import os
  5. import time
  6. import torch
  7. from mmcv import Config, DictAction
  8. from mmcv.cnn import fuse_conv_bn
  9. from mmcv.parallel import MMDistributedDataParallel
  10. from mmcv.runner import init_dist, load_checkpoint, wrap_fp16_model
  11. from mmdet.datasets import (build_dataloader, build_dataset,
  12. replace_ImageToTensor)
  13. from mmdet.models import build_detector
  14. def parse_args():
  15. parser = argparse.ArgumentParser(description='MMDet benchmark a model')
  16. parser.add_argument('config', help='test config file path')
  17. parser.add_argument('checkpoint', help='checkpoint file')
  18. parser.add_argument(
  19. '--repeat-num',
  20. type=int,
  21. default=1,
  22. help='number of repeat times of measurement for averaging the results')
  23. parser.add_argument(
  24. '--max-iter', type=int, default=2000, help='num of max iter')
  25. parser.add_argument(
  26. '--log-interval', type=int, default=50, help='interval of logging')
  27. parser.add_argument(
  28. '--fuse-conv-bn',
  29. action='store_true',
  30. help='Whether to fuse conv and bn, this will slightly increase'
  31. 'the inference speed')
  32. parser.add_argument(
  33. '--cfg-options',
  34. nargs='+',
  35. action=DictAction,
  36. help='override some settings in the used config, the key-value pair '
  37. 'in xxx=yyy format will be merged into config file. If the value to '
  38. 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
  39. 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
  40. 'Note that the quotation marks are necessary and that no white space '
  41. 'is allowed.')
  42. parser.add_argument(
  43. '--launcher',
  44. choices=['none', 'pytorch', 'slurm', 'mpi'],
  45. default='none',
  46. help='job launcher')
  47. parser.add_argument('--local_rank', type=int, default=0)
  48. args = parser.parse_args()
  49. if 'LOCAL_RANK' not in os.environ:
  50. os.environ['LOCAL_RANK'] = str(args.local_rank)
  51. return args
  52. def measure_inference_speed(cfg, checkpoint, max_iter, log_interval,
  53. is_fuse_conv_bn):
  54. # set cudnn_benchmark
  55. if cfg.get('cudnn_benchmark', False):
  56. torch.backends.cudnn.benchmark = True
  57. cfg.model.pretrained = None
  58. cfg.data.test.test_mode = True
  59. # build the dataloader
  60. samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
  61. if samples_per_gpu > 1:
  62. # Replace 'ImageToTensor' to 'DefaultFormatBundle'
  63. cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
  64. dataset = build_dataset(cfg.data.test)
  65. data_loader = build_dataloader(
  66. dataset,
  67. samples_per_gpu=1,
  68. # Because multiple processes will occupy additional CPU resources,
  69. # FPS statistics will be more unstable when workers_per_gpu is not 0.
  70. # It is reasonable to set workers_per_gpu to 0.
  71. workers_per_gpu=0,
  72. dist=True,
  73. shuffle=False)
  74. # build the model and load checkpoint
  75. cfg.model.train_cfg = None
  76. model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
  77. fp16_cfg = cfg.get('fp16', None)
  78. if fp16_cfg is not None:
  79. wrap_fp16_model(model)
  80. load_checkpoint(model, checkpoint, map_location='cpu')
  81. if is_fuse_conv_bn:
  82. model = fuse_conv_bn(model)
  83. model = MMDistributedDataParallel(
  84. model.cuda(),
  85. device_ids=[torch.cuda.current_device()],
  86. broadcast_buffers=False)
  87. model.eval()
  88. # the first several iterations may be very slow so skip them
  89. num_warmup = 5
  90. pure_inf_time = 0
  91. fps = 0
  92. # benchmark with 2000 image and take the average
  93. for i, data in enumerate(data_loader):
  94. torch.cuda.synchronize()
  95. start_time = time.perf_counter()
  96. with torch.no_grad():
  97. model(return_loss=False, rescale=True, **data)
  98. torch.cuda.synchronize()
  99. elapsed = time.perf_counter() - start_time
  100. if i >= num_warmup:
  101. pure_inf_time += elapsed
  102. if (i + 1) % log_interval == 0:
  103. fps = (i + 1 - num_warmup) / pure_inf_time
  104. print(
  105. f'Done image [{i + 1:<3}/ {max_iter}], '
  106. f'fps: {fps:.1f} img / s, '
  107. f'times per image: {1000 / fps:.1f} ms / img',
  108. flush=True)
  109. if (i + 1) == max_iter:
  110. fps = (i + 1 - num_warmup) / pure_inf_time
  111. print(
  112. f'Overall fps: {fps:.1f} img / s, '
  113. f'times per image: {1000 / fps:.1f} ms / img',
  114. flush=True)
  115. break
  116. return fps
  117. def repeat_measure_inference_speed(cfg,
  118. checkpoint,
  119. max_iter,
  120. log_interval,
  121. is_fuse_conv_bn,
  122. repeat_num=1):
  123. assert repeat_num >= 1
  124. fps_list = []
  125. for _ in range(repeat_num):
  126. #
  127. cp_cfg = copy.deepcopy(cfg)
  128. fps_list.append(
  129. measure_inference_speed(cp_cfg, checkpoint, max_iter, log_interval,
  130. is_fuse_conv_bn))
  131. if repeat_num > 1:
  132. fps_list_ = [round(fps, 1) for fps in fps_list]
  133. times_pre_image_list_ = [round(1000 / fps, 1) for fps in fps_list]
  134. mean_fps_ = sum(fps_list_) / len(fps_list_)
  135. mean_times_pre_image_ = sum(times_pre_image_list_) / len(
  136. times_pre_image_list_)
  137. print(
  138. f'Overall fps: {fps_list_}[{mean_fps_:.1f}] img / s, '
  139. f'times per image: '
  140. f'{times_pre_image_list_}[{mean_times_pre_image_:.1f}] ms / img',
  141. flush=True)
  142. return fps_list
  143. return fps_list[0]
  144. def main():
  145. args = parse_args()
  146. cfg = Config.fromfile(args.config)
  147. if args.cfg_options is not None:
  148. cfg.merge_from_dict(args.cfg_options)
  149. if args.launcher == 'none':
  150. raise NotImplementedError('Only supports distributed mode')
  151. else:
  152. init_dist(args.launcher, **cfg.dist_params)
  153. repeat_measure_inference_speed(cfg, args.checkpoint, args.max_iter,
  154. args.log_interval, args.fuse_conv_bn,
  155. args.repeat_num)
  156. if __name__ == '__main__':
  157. main()

No Description

Contributors (1)