You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

get_flops.py 3.2 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import numpy as np
  4. import torch
  5. from mmcv import Config, DictAction
  6. from mmdet.models import build_detector
  7. try:
  8. from mmcv.cnn import get_model_complexity_info
  9. except ImportError:
  10. raise ImportError('Please upgrade mmcv to >0.6.2')
  11. def parse_args():
  12. parser = argparse.ArgumentParser(description='Train a detector')
  13. parser.add_argument('config', help='train config file path')
  14. parser.add_argument(
  15. '--shape',
  16. type=int,
  17. nargs='+',
  18. default=[1280, 800],
  19. help='input image size')
  20. parser.add_argument(
  21. '--cfg-options',
  22. nargs='+',
  23. action=DictAction,
  24. help='override some settings in the used config, the key-value pair '
  25. 'in xxx=yyy format will be merged into config file. If the value to '
  26. 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
  27. 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
  28. 'Note that the quotation marks are necessary and that no white space '
  29. 'is allowed.')
  30. parser.add_argument(
  31. '--size-divisor',
  32. type=int,
  33. default=32,
  34. help='Pad the input image, the minimum size that is divisible '
  35. 'by size_divisor, -1 means do not pad the image.')
  36. args = parser.parse_args()
  37. return args
  38. def main():
  39. args = parse_args()
  40. if len(args.shape) == 1:
  41. h = w = args.shape[0]
  42. elif len(args.shape) == 2:
  43. h, w = args.shape
  44. else:
  45. raise ValueError('invalid input shape')
  46. orig_shape = (3, h, w)
  47. divisor = args.size_divisor
  48. if divisor > 0:
  49. h = int(np.ceil(h / divisor)) * divisor
  50. w = int(np.ceil(w / divisor)) * divisor
  51. input_shape = (3, h, w)
  52. cfg = Config.fromfile(args.config)
  53. if args.cfg_options is not None:
  54. cfg.merge_from_dict(args.cfg_options)
  55. # import modules from string list.
  56. if cfg.get('custom_imports', None):
  57. from mmcv.utils import import_modules_from_strings
  58. import_modules_from_strings(**cfg['custom_imports'])
  59. model = build_detector(
  60. cfg.model,
  61. train_cfg=cfg.get('train_cfg'),
  62. test_cfg=cfg.get('test_cfg'))
  63. if torch.cuda.is_available():
  64. model.cuda()
  65. model.eval()
  66. if hasattr(model, 'forward_dummy'):
  67. model.forward = model.forward_dummy
  68. else:
  69. raise NotImplementedError(
  70. 'FLOPs counter is currently not currently supported with {}'.
  71. format(model.__class__.__name__))
  72. flops, params = get_model_complexity_info(model, input_shape)
  73. split_line = '=' * 30
  74. if divisor > 0 and \
  75. input_shape != orig_shape:
  76. print(f'{split_line}\nUse size divisor set input shape '
  77. f'from {orig_shape} to {input_shape}\n')
  78. print(f'{split_line}\nInput shape: {input_shape}\n'
  79. f'Flops: {flops}\nParams: {params}\n{split_line}')
  80. print('!!!Please be cautious if you use the results in papers. '
  81. 'You may need to check if all ops are supported and verify that the '
  82. 'flops computation is correct.')
  83. if __name__ == '__main__':
  84. main()

No Description

Contributors (1)