You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

AOI_to_coco.py 8.2 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
"""
Convert a box-annotated image dataset into COCO-format JSON.

NOTE: although described as "YOLO format", the label files are not normalised
YOLO ``cls cx cy w h`` rows. Only the first line of each label file under
``labels/`` is read. After a leading token, that line holds space-separated
``x1,y1,x2,y2,cls`` groups in pixel coordinates.

Command-line options:
--root_dir        input root directory (contains images/, labels/ and classes.txt)
--save_path       output file name, used only when no split option is given;
                  it is placed under <root_dir>/annotations/
--random_split    randomly split the dataset, then save it as three files
                  (train / val / test)
--split_by_file   split the dataset according to ./train.txt ./val.txt ./test.txt
"""
  8. import os
  9. import cv2
  10. import json
  11. from tqdm import tqdm
  12. from sklearn.model_selection import train_test_split
  13. import argparse
  14. import numpy as np
  15. parser = argparse.ArgumentParser()
  16. parser.add_argument('--root_dir', default='/home/shanwei-luo/teamdata/anomaly_detection_active_learning/data0422/smd12_11_12_hard_score_04/train',type=str, help="root path of images and labels, include ./images and ./labels and classes.txt")
  17. parser.add_argument('--save_path', type=str,default='./train.json', help="if not split the dataset, give a path to a json file")
  18. parser.add_argument('--random_split', action='store_true', help="random split the dataset, default ratio is 8:1:1")
  19. parser.add_argument('--split_by_file', action='store_true', help="define how to split the dataset, include ./train.txt ./val.txt ./test.txt ")
  20. arg = parser.parse_args()
  21. def train_test_val_split_random(img_paths,ratio_train=0.8,ratio_test=0.1,ratio_val=0.1):
  22. # 这里可以修改数据集划分的比例。
  23. assert int(ratio_train+ratio_test+ratio_val) == 1
  24. train_img, middle_img = train_test_split(img_paths,test_size=1-ratio_train, random_state=233)
  25. ratio=ratio_val/(1-ratio_train)
  26. val_img, test_img =train_test_split(middle_img,test_size=ratio, random_state=233)
  27. print("NUMS of train:val:test = {}:{}:{}".format(len(train_img), len(val_img), len(test_img)))
  28. return train_img, val_img, test_img
  29. def train_test_val_split_by_files(img_paths, root_dir):
  30. # 根据文件 train.txt, val.txt, test.txt(里面写的都是对应集合的图片名字) 来定义训练集、验证集和测试集
  31. phases = ['train', 'val', 'test']
  32. img_split = []
  33. for p in phases:
  34. define_path = os.path.join(root_dir, f'{p}.txt')
  35. print(f'Read {p} dataset definition from {define_path}')
  36. assert os.path.exists(define_path)
  37. with open(define_path, 'r') as f:
  38. img_paths = f.readlines()
  39. # img_paths = [os.path.split(img_path.strip())[1] for img_path in img_paths] # NOTE 取消这句备注可以读取绝对地址。
  40. img_split.append(img_paths)
  41. return img_split[0], img_split[1], img_split[2]
  42. def yolo2coco(arg):
  43. root_path = arg.root_dir
  44. print("Loading data from ",root_path)
  45. assert os.path.exists(root_path)
  46. originLabelsDir = os.path.join(root_path, 'labels')
  47. originImagesDir = os.path.join(root_path, 'images')
  48. with open(os.path.join(root_path, 'classes.txt')) as f:
  49. classes = f.read().strip().split()
  50. # images dir name
  51. indexes = os.listdir(originImagesDir)
  52. if arg.random_split or arg.split_by_file:
  53. # 用于保存所有数据的图片信息和标注信息
  54. train_dataset = {'categories': [], 'annotations': [], 'images': []}
  55. val_dataset = {'categories': [], 'annotations': [], 'images': []}
  56. test_dataset = {'categories': [], 'annotations': [], 'images': []}
  57. # 建立类别标签和数字id的对应关系, 类别id从0开始。
  58. for i, cls in enumerate(classes, 0):
  59. train_dataset['categories'].append({'id': i, 'name': cls, 'supercategory': 'mark'})
  60. val_dataset['categories'].append({'id': i, 'name': cls, 'supercategory': 'mark'})
  61. test_dataset['categories'].append({'id': i, 'name': cls, 'supercategory': 'mark'})
  62. if arg.random_split:
  63. print("spliting mode: random split")
  64. train_img, val_img, test_img = train_test_val_split_random(indexes,0.8,0.1,0.1)
  65. elif arg.split_by_file:
  66. print("spliting mode: split by files")
  67. train_img, val_img, test_img = train_test_val_split_by_files(indexes, root_path)
  68. else:
  69. dataset = {'categories': [], 'annotations': [], 'images': []}
  70. for i, cls in enumerate(classes, 0):
  71. dataset['categories'].append({'id': i, 'name': cls, 'supercategory': 'mark'})
  72. # 标注的id
  73. ann_id_cnt = 0
  74. ans = 0
  75. for k, index in enumerate(tqdm(indexes)):
  76. # 支持 png jpg 格式的图片。
  77. txtFile = index.replace('images','txt').replace('.jpg','.txt').replace('.png','.txt')
  78. print(txtFile)
  79. # 读取图像的宽和高
  80. #im = cv2.imread(os.path.join(root_path, 'images/') + index)
  81. im = cv2.imdecode(np.fromfile(os.path.join(root_path, 'images/') + index, dtype=np.uint8), cv2.IMREAD_COLOR)
  82. height, width, _ = im.shape
  83. if arg.random_split or arg.split_by_file:
  84. # 切换dataset的引用对象,从而划分数据集
  85. if index in train_img:
  86. dataset = train_dataset
  87. elif index in val_img:
  88. dataset = val_dataset
  89. elif index in test_img:
  90. dataset = test_dataset
  91. # 添加图像的信息
  92. dataset['images'].append({'file_name': index,
  93. 'id': k,
  94. 'width': width,
  95. 'height': height})
  96. if not os.path.exists(os.path.join(originLabelsDir, txtFile)):
  97. # 如没标签,跳过,只保留图片信息。
  98. continue
  99. with open(os.path.join(originLabelsDir, txtFile), 'r') as fr:
  100. labelList = fr.readline()
  101. labelList = labelList.strip().split(" ")
  102. if len(labelList)==1:
  103. continue
  104. labelList = labelList[1:]
  105. for label in labelList:
  106. label = label.split(",")
  107. cls_id = int(label[4]) - 1
  108. x1 = float(label[0])
  109. y1 = float(label[1])
  110. x2 = float(label[2])
  111. y2 = float(label[3])
  112. # convert x,y,w,h to x1,y1,x2,y2
  113. H, W, _ = im.shape
  114. if x1<0:
  115. x1 = 0
  116. elif x1>W:
  117. x1 = W-1
  118. if x2<0:
  119. x2 = 0
  120. elif x2>W:
  121. x2 = W-1
  122. if y1<0:
  123. y1 = 0
  124. elif y1>H:
  125. y1 = H-1
  126. if y2<0:
  127. y2 = 0
  128. elif y2>H:
  129. y2 = H-1
  130. # 标签序号从0开始计算, coco2017数据集标号混乱,不管它了。
  131. width = max(0, x2 - x1)
  132. height = max(0, y2 - y1)
  133. dataset['annotations'].append({
  134. 'area': width * height,
  135. 'bbox': [x1, y1, width, height],
  136. 'category_id': cls_id,
  137. 'id': ann_id_cnt,
  138. 'image_id': k,
  139. 'iscrowd': 0,
  140. # mask, 矩形是从左上角点按顺时针的四个顶点
  141. 'segmentation': [[x1, y1, x2, y1, x2, y2, x1, y2]]
  142. })
  143. ann_id_cnt += 1
  144. # 保存结果
  145. print(ann_id_cnt)
  146. folder = os.path.join(root_path, 'annotations')
  147. if not os.path.exists(folder):
  148. os.makedirs(folder)
  149. if arg.random_split or arg.split_by_file:
  150. for phase in ['train','val','test']:
  151. json_name = os.path.join(root_path, 'annotations/{}.json'.format(phase))
  152. with open(json_name, 'w') as f:
  153. if phase == 'train':
  154. json.dump(train_dataset, f)
  155. elif phase == 'val':
  156. json.dump(val_dataset, f)
  157. elif phase == 'test':
  158. json.dump(test_dataset, f)
  159. print('Save annotation to {}'.format(json_name))
  160. else:
  161. json_name = os.path.join(root_path, 'annotations/{}'.format(arg.save_path))
  162. with open(json_name, 'w') as f:
  163. json.dump(dataset, f)
  164. print('Save annotation to {}'.format(json_name))
  165. if __name__ == "__main__":
  166. yolo2coco(arg)

No Description

Contributors (3)