You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

cityscapes.py 5.2 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import glob
  4. import os.path as osp
  5. import cityscapesscripts.helpers.labels as CSLabels
  6. import mmcv
  7. import numpy as np
  8. import pycocotools.mask as maskUtils
  9. def collect_files(img_dir, gt_dir):
  10. suffix = 'leftImg8bit.png'
  11. files = []
  12. for img_file in glob.glob(osp.join(img_dir, '**/*.png')):
  13. assert img_file.endswith(suffix), img_file
  14. inst_file = gt_dir + img_file[
  15. len(img_dir):-len(suffix)] + 'gtFine_instanceIds.png'
  16. # Note that labelIds are not converted to trainId for seg map
  17. segm_file = gt_dir + img_file[
  18. len(img_dir):-len(suffix)] + 'gtFine_labelIds.png'
  19. files.append((img_file, inst_file, segm_file))
  20. assert len(files), f'No images found in {img_dir}'
  21. print(f'Loaded {len(files)} images from {img_dir}')
  22. return files
  23. def collect_annotations(files, nproc=1):
  24. print('Loading annotation images')
  25. if nproc > 1:
  26. images = mmcv.track_parallel_progress(
  27. load_img_info, files, nproc=nproc)
  28. else:
  29. images = mmcv.track_progress(load_img_info, files)
  30. return images
  31. def load_img_info(files):
  32. img_file, inst_file, segm_file = files
  33. inst_img = mmcv.imread(inst_file, 'unchanged')
  34. # ids < 24 are stuff labels (filtering them first is about 5% faster)
  35. unique_inst_ids = np.unique(inst_img[inst_img >= 24])
  36. anno_info = []
  37. for inst_id in unique_inst_ids:
  38. # For non-crowd annotations, inst_id // 1000 is the label_id
  39. # Crowd annotations have <1000 instance ids
  40. label_id = inst_id // 1000 if inst_id >= 1000 else inst_id
  41. label = CSLabels.id2label[label_id]
  42. if not label.hasInstances or label.ignoreInEval:
  43. continue
  44. category_id = label.id
  45. iscrowd = int(inst_id < 1000)
  46. mask = np.asarray(inst_img == inst_id, dtype=np.uint8, order='F')
  47. mask_rle = maskUtils.encode(mask[:, :, None])[0]
  48. area = maskUtils.area(mask_rle)
  49. # convert to COCO style XYWH format
  50. bbox = maskUtils.toBbox(mask_rle)
  51. # for json encoding
  52. mask_rle['counts'] = mask_rle['counts'].decode()
  53. anno = dict(
  54. iscrowd=iscrowd,
  55. category_id=category_id,
  56. bbox=bbox.tolist(),
  57. area=area.tolist(),
  58. segmentation=mask_rle)
  59. anno_info.append(anno)
  60. video_name = osp.basename(osp.dirname(img_file))
  61. img_info = dict(
  62. # remove img_prefix for filename
  63. file_name=osp.join(video_name, osp.basename(img_file)),
  64. height=inst_img.shape[0],
  65. width=inst_img.shape[1],
  66. anno_info=anno_info,
  67. segm_file=osp.join(video_name, osp.basename(segm_file)))
  68. return img_info
  69. def cvt_annotations(image_infos, out_json_name):
  70. out_json = dict()
  71. img_id = 0
  72. ann_id = 0
  73. out_json['images'] = []
  74. out_json['categories'] = []
  75. out_json['annotations'] = []
  76. for image_info in image_infos:
  77. image_info['id'] = img_id
  78. anno_infos = image_info.pop('anno_info')
  79. out_json['images'].append(image_info)
  80. for anno_info in anno_infos:
  81. anno_info['image_id'] = img_id
  82. anno_info['id'] = ann_id
  83. out_json['annotations'].append(anno_info)
  84. ann_id += 1
  85. img_id += 1
  86. for label in CSLabels.labels:
  87. if label.hasInstances and not label.ignoreInEval:
  88. cat = dict(id=label.id, name=label.name)
  89. out_json['categories'].append(cat)
  90. if len(out_json['annotations']) == 0:
  91. out_json.pop('annotations')
  92. mmcv.dump(out_json, out_json_name)
  93. return out_json
  94. def parse_args():
  95. parser = argparse.ArgumentParser(
  96. description='Convert Cityscapes annotations to COCO format')
  97. parser.add_argument('cityscapes_path', help='cityscapes data path')
  98. parser.add_argument('--img-dir', default='leftImg8bit', type=str)
  99. parser.add_argument('--gt-dir', default='gtFine', type=str)
  100. parser.add_argument('-o', '--out-dir', help='output path')
  101. parser.add_argument(
  102. '--nproc', default=1, type=int, help='number of process')
  103. args = parser.parse_args()
  104. return args
  105. def main():
  106. args = parse_args()
  107. cityscapes_path = args.cityscapes_path
  108. out_dir = args.out_dir if args.out_dir else cityscapes_path
  109. mmcv.mkdir_or_exist(out_dir)
  110. img_dir = osp.join(cityscapes_path, args.img_dir)
  111. gt_dir = osp.join(cityscapes_path, args.gt_dir)
  112. set_name = dict(
  113. train='instancesonly_filtered_gtFine_train.json',
  114. val='instancesonly_filtered_gtFine_val.json',
  115. test='instancesonly_filtered_gtFine_test.json')
  116. for split, json_name in set_name.items():
  117. print(f'Converting {split} into {json_name}')
  118. with mmcv.Timer(
  119. print_tmpl='It took {}s to convert Cityscapes annotation'):
  120. files = collect_files(
  121. osp.join(img_dir, split), osp.join(gt_dir, split))
  122. image_infos = collect_annotations(files, nproc=args.nproc)
  123. cvt_annotations(image_infos, osp.join(out_dir, json_name))
# Script entry point: run the conversion when executed directly.
if __name__ == '__main__':
    main()

No Description

Contributors (3)