You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

optimize_anchors.py 13 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. """Optimize anchor settings on a specific dataset.
  3. This script provides two methods to optimize YOLO anchors: k-means
  4. anchor clustering and differential evolution. You can use ``--algorithm k-means``
  5. and ``--algorithm differential_evolution`` to switch between the two methods.
  6. Example:
  7. Use k-means anchor cluster::
  8. python tools/analysis_tools/optimize_anchors.py ${CONFIG} \
  9. --algorithm k-means --input-shape ${INPUT_SHAPE [WIDTH HEIGHT]} \
  10. --output-dir ${OUTPUT_DIR}
  11. Use differential evolution to optimize anchors::
  12. python tools/analysis_tools/optimize_anchors.py ${CONFIG} \
  13. --algorithm differential_evolution \
  14. --input-shape ${INPUT_SHAPE [WIDTH HEIGHT]} \
  15. --output-dir ${OUTPUT_DIR}
  16. """
  17. import argparse
  18. import os.path as osp
  19. import mmcv
  20. import numpy as np
  21. import torch
  22. from mmcv import Config
  23. from scipy.optimize import differential_evolution
  24. from mmdet.core import bbox_cxcywh_to_xyxy, bbox_overlaps, bbox_xyxy_to_cxcywh
  25. from mmdet.datasets import build_dataset
  26. from mmdet.utils import get_root_logger
  27. def parse_args():
  28. parser = argparse.ArgumentParser(description='Optimize anchor parameters.')
  29. parser.add_argument('config', help='Train config file path.')
  30. parser.add_argument(
  31. '--device', default='cuda:0', help='Device used for calculating.')
  32. parser.add_argument(
  33. '--input-shape',
  34. type=int,
  35. nargs='+',
  36. default=[608, 608],
  37. help='input image size')
  38. parser.add_argument(
  39. '--algorithm',
  40. default='differential_evolution',
  41. help='Algorithm used for anchor optimizing.'
  42. 'Support k-means and differential_evolution for YOLO.')
  43. parser.add_argument(
  44. '--iters',
  45. default=1000,
  46. type=int,
  47. help='Maximum iterations for optimizer.')
  48. parser.add_argument(
  49. '--output-dir',
  50. default=None,
  51. type=str,
  52. help='Path to save anchor optimize result.')
  53. args = parser.parse_args()
  54. return args
  55. class BaseAnchorOptimizer:
  56. """Base class for anchor optimizer.
  57. Args:
  58. dataset (obj:`Dataset`): Dataset object.
  59. input_shape (list[int]): Input image shape of the model.
  60. Format in [width, height].
  61. logger (obj:`logging.Logger`): The logger for logging.
  62. device (str, optional): Device used for calculating.
  63. Default: 'cuda:0'
  64. out_dir (str, optional): Path to save anchor optimize result.
  65. Default: None
  66. """
  67. def __init__(self,
  68. dataset,
  69. input_shape,
  70. logger,
  71. device='cuda:0',
  72. out_dir=None):
  73. self.dataset = dataset
  74. self.input_shape = input_shape
  75. self.logger = logger
  76. self.device = device
  77. self.out_dir = out_dir
  78. bbox_whs, img_shapes = self.get_whs_and_shapes()
  79. ratios = img_shapes.max(1, keepdims=True) / np.array([input_shape])
  80. # resize to input shape
  81. self.bbox_whs = bbox_whs / ratios
  82. def get_whs_and_shapes(self):
  83. """Get widths and heights of bboxes and shapes of images.
  84. Returns:
  85. tuple[np.ndarray]: Array of bbox shapes and array of image
  86. shapes with shape (num_bboxes, 2) in [width, height] format.
  87. """
  88. self.logger.info('Collecting bboxes from annotation...')
  89. bbox_whs = []
  90. img_shapes = []
  91. prog_bar = mmcv.ProgressBar(len(self.dataset))
  92. for idx in range(len(self.dataset)):
  93. ann = self.dataset.get_ann_info(idx)
  94. data_info = self.dataset.data_infos[idx]
  95. img_shape = np.array([data_info['width'], data_info['height']])
  96. gt_bboxes = ann['bboxes']
  97. for bbox in gt_bboxes:
  98. wh = bbox[2:4] - bbox[0:2]
  99. img_shapes.append(img_shape)
  100. bbox_whs.append(wh)
  101. prog_bar.update()
  102. print('\n')
  103. bbox_whs = np.array(bbox_whs)
  104. img_shapes = np.array(img_shapes)
  105. self.logger.info(f'Collected {bbox_whs.shape[0]} bboxes.')
  106. return bbox_whs, img_shapes
  107. def get_zero_center_bbox_tensor(self):
  108. """Get a tensor of bboxes centered at (0, 0).
  109. Returns:
  110. Tensor: Tensor of bboxes with shape (num_bboxes, 4)
  111. in [xmin, ymin, xmax, ymax] format.
  112. """
  113. whs = torch.from_numpy(self.bbox_whs).to(
  114. self.device, dtype=torch.float32)
  115. bboxes = bbox_cxcywh_to_xyxy(
  116. torch.cat([torch.zeros_like(whs), whs], dim=1))
  117. return bboxes
  118. def optimize(self):
  119. raise NotImplementedError
  120. def save_result(self, anchors, path=None):
  121. anchor_results = []
  122. for w, h in anchors:
  123. anchor_results.append([round(w), round(h)])
  124. self.logger.info(f'Anchor optimize result:{anchor_results}')
  125. if path:
  126. json_path = osp.join(path, 'anchor_optimize_result.json')
  127. mmcv.dump(anchor_results, json_path)
  128. self.logger.info(f'Result saved in {json_path}')
  129. class YOLOKMeansAnchorOptimizer(BaseAnchorOptimizer):
  130. r"""YOLO anchor optimizer using k-means. Code refer to `AlexeyAB/darknet.
  131. <https://github.com/AlexeyAB/darknet/blob/master/src/detector.c>`_.
  132. Args:
  133. num_anchors (int) : Number of anchors.
  134. iters (int): Maximum iterations for k-means.
  135. """
  136. def __init__(self, num_anchors, iters, **kwargs):
  137. super(YOLOKMeansAnchorOptimizer, self).__init__(**kwargs)
  138. self.num_anchors = num_anchors
  139. self.iters = iters
  140. def optimize(self):
  141. anchors = self.kmeans_anchors()
  142. self.save_result(anchors, self.out_dir)
  143. def kmeans_anchors(self):
  144. self.logger.info(
  145. f'Start cluster {self.num_anchors} YOLO anchors with K-means...')
  146. bboxes = self.get_zero_center_bbox_tensor()
  147. cluster_center_idx = torch.randint(
  148. 0, bboxes.shape[0], (self.num_anchors, )).to(self.device)
  149. assignments = torch.zeros((bboxes.shape[0], )).to(self.device)
  150. cluster_centers = bboxes[cluster_center_idx]
  151. if self.num_anchors == 1:
  152. cluster_centers = self.kmeans_maximization(bboxes, assignments,
  153. cluster_centers)
  154. anchors = bbox_xyxy_to_cxcywh(cluster_centers)[:, 2:].cpu().numpy()
  155. anchors = sorted(anchors, key=lambda x: x[0] * x[1])
  156. return anchors
  157. prog_bar = mmcv.ProgressBar(self.iters)
  158. for i in range(self.iters):
  159. converged, assignments = self.kmeans_expectation(
  160. bboxes, assignments, cluster_centers)
  161. if converged:
  162. self.logger.info(f'K-means process has converged at iter {i}.')
  163. break
  164. cluster_centers = self.kmeans_maximization(bboxes, assignments,
  165. cluster_centers)
  166. prog_bar.update()
  167. print('\n')
  168. avg_iou = bbox_overlaps(bboxes,
  169. cluster_centers).max(1)[0].mean().item()
  170. anchors = bbox_xyxy_to_cxcywh(cluster_centers)[:, 2:].cpu().numpy()
  171. anchors = sorted(anchors, key=lambda x: x[0] * x[1])
  172. self.logger.info(f'Anchor cluster finish. Average IOU: {avg_iou}')
  173. return anchors
  174. def kmeans_maximization(self, bboxes, assignments, centers):
  175. """Maximization part of EM algorithm(Expectation-Maximization)"""
  176. new_centers = torch.zeros_like(centers)
  177. for i in range(centers.shape[0]):
  178. mask = (assignments == i)
  179. if mask.sum():
  180. new_centers[i, :] = bboxes[mask].mean(0)
  181. return new_centers
  182. def kmeans_expectation(self, bboxes, assignments, centers):
  183. """Expectation part of EM algorithm(Expectation-Maximization)"""
  184. ious = bbox_overlaps(bboxes, centers)
  185. closest = ious.argmax(1)
  186. converged = (closest == assignments).all()
  187. return converged, closest
  188. class YOLODEAnchorOptimizer(BaseAnchorOptimizer):
  189. """YOLO anchor optimizer using differential evolution algorithm.
  190. Args:
  191. num_anchors (int) : Number of anchors.
  192. iters (int): Maximum iterations for k-means.
  193. strategy (str): The differential evolution strategy to use.
  194. Should be one of:
  195. - 'best1bin'
  196. - 'best1exp'
  197. - 'rand1exp'
  198. - 'randtobest1exp'
  199. - 'currenttobest1exp'
  200. - 'best2exp'
  201. - 'rand2exp'
  202. - 'randtobest1bin'
  203. - 'currenttobest1bin'
  204. - 'best2bin'
  205. - 'rand2bin'
  206. - 'rand1bin'
  207. Default: 'best1bin'.
  208. population_size (int): Total population size of evolution algorithm.
  209. Default: 15.
  210. convergence_thr (float): Tolerance for convergence, the
  211. optimizing stops when ``np.std(pop) <= abs(convergence_thr)
  212. + convergence_thr * np.abs(np.mean(population_energies))``,
  213. respectively. Default: 0.0001.
  214. mutation (tuple[float]): Range of dithering randomly changes the
  215. mutation constant. Default: (0.5, 1).
  216. recombination (float): Recombination constant of crossover probability.
  217. Default: 0.7.
  218. """
  219. def __init__(self,
  220. num_anchors,
  221. iters,
  222. strategy='best1bin',
  223. population_size=15,
  224. convergence_thr=0.0001,
  225. mutation=(0.5, 1),
  226. recombination=0.7,
  227. **kwargs):
  228. super(YOLODEAnchorOptimizer, self).__init__(**kwargs)
  229. self.num_anchors = num_anchors
  230. self.iters = iters
  231. self.strategy = strategy
  232. self.population_size = population_size
  233. self.convergence_thr = convergence_thr
  234. self.mutation = mutation
  235. self.recombination = recombination
  236. def optimize(self):
  237. anchors = self.differential_evolution()
  238. self.save_result(anchors, self.out_dir)
  239. def differential_evolution(self):
  240. bboxes = self.get_zero_center_bbox_tensor()
  241. bounds = []
  242. for i in range(self.num_anchors):
  243. bounds.extend([(0, self.input_shape[0]), (0, self.input_shape[1])])
  244. result = differential_evolution(
  245. func=self.avg_iou_cost,
  246. bounds=bounds,
  247. args=(bboxes, ),
  248. strategy=self.strategy,
  249. maxiter=self.iters,
  250. popsize=self.population_size,
  251. tol=self.convergence_thr,
  252. mutation=self.mutation,
  253. recombination=self.recombination,
  254. updating='immediate',
  255. disp=True)
  256. self.logger.info(
  257. f'Anchor evolution finish. Average IOU: {1 - result.fun}')
  258. anchors = [(w, h) for w, h in zip(result.x[::2], result.x[1::2])]
  259. anchors = sorted(anchors, key=lambda x: x[0] * x[1])
  260. return anchors
  261. @staticmethod
  262. def avg_iou_cost(anchor_params, bboxes):
  263. assert len(anchor_params) % 2 == 0
  264. anchor_whs = torch.tensor(
  265. [[w, h]
  266. for w, h in zip(anchor_params[::2], anchor_params[1::2])]).to(
  267. bboxes.device, dtype=bboxes.dtype)
  268. anchor_boxes = bbox_cxcywh_to_xyxy(
  269. torch.cat([torch.zeros_like(anchor_whs), anchor_whs], dim=1))
  270. ious = bbox_overlaps(bboxes, anchor_boxes)
  271. max_ious, _ = ious.max(1)
  272. cost = 1 - max_ious.mean().item()
  273. return cost
  274. def main():
  275. logger = get_root_logger()
  276. args = parse_args()
  277. cfg = args.config
  278. cfg = Config.fromfile(cfg)
  279. input_shape = args.input_shape
  280. assert len(input_shape) == 2
  281. anchor_type = cfg.model.bbox_head.anchor_generator.type
  282. assert anchor_type == 'YOLOAnchorGenerator', \
  283. f'Only support optimize YOLOAnchor, but get {anchor_type}.'
  284. base_sizes = cfg.model.bbox_head.anchor_generator.base_sizes
  285. num_anchors = sum([len(sizes) for sizes in base_sizes])
  286. train_data_cfg = cfg.data.train
  287. while 'dataset' in train_data_cfg:
  288. train_data_cfg = train_data_cfg['dataset']
  289. dataset = build_dataset(train_data_cfg)
  290. if args.algorithm == 'k-means':
  291. optimizer = YOLOKMeansAnchorOptimizer(
  292. dataset=dataset,
  293. input_shape=input_shape,
  294. device=args.device,
  295. num_anchors=num_anchors,
  296. iters=args.iters,
  297. logger=logger,
  298. out_dir=args.output_dir)
  299. elif args.algorithm == 'differential_evolution':
  300. optimizer = YOLODEAnchorOptimizer(
  301. dataset=dataset,
  302. input_shape=input_shape,
  303. device=args.device,
  304. num_anchors=num_anchors,
  305. iters=args.iters,
  306. logger=logger,
  307. out_dir=args.output_dir)
  308. else:
  309. raise NotImplementedError(
  310. f'Only support k-means and differential_evolution, '
  311. f'but get {args.algorithm}')
  312. optimizer.optimize()
  313. if __name__ == '__main__':
  314. main()

No Description

Contributors (3)