You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_coco.py 2.9 kB

2 years ago
Lines 1–90
  1. import sys
  2. import os
  3. import random
  4. import numpy as np
  5. import json
  6. sys.path.append("/home/shanwei-luo/userdata/mmdetection")
  7. from mmdet.apis import (async_inference_detector, inference_detector,
  8. init_detector, show_result_pyplot)
  9. data_coco = json.load(open('/home/shanwei-luo/userdata/pd_datasets/20220606_coco/annotations/val_cat_mode.json'))
  10. data_name = data_coco["images"]
  11. data_ann = data_coco['annotations']
  12. boxes = {}
  13. for res in data_ann:
  14. print(res)
  15. img_id_1 = res["image_id"]
  16. img_name = data_name[int(img_id_1)]["file_name"]
  17. #print(img_name)
  18. #print(res)
  19. bbox = res["bbox"]
  20. label = res["category_id"]
  21. bbox.append(int(label))
  22. if img_name in boxes.keys():
  23. boxes[img_name].append(bbox)
  24. else:
  25. boxes[img_name]=[]
  26. boxes[img_name].append(bbox)
  27. #print(boxes)
  28. config_file_1 = '/home/shanwei-luo/userdata/mmdetection/work_dirs/AD_pd_test01/AD_pd_test01.py'
  29. checkpoint_file_1 = '/home/shanwei-luo/userdata/mmdetection/work_dirs/AD_pd_test01/epoch_18.pth'
  30. img_path = '/home/shanwei-luo/userdata/pd_datasets/20220606_coco/images/'
  31. model_1 = init_detector(config_file_1, checkpoint_file_1, device='cuda:0')
  32. imgs = os.listdir(img_path)
  33. #img_id = random.randint(0, len(label_path))
  34. imgs_labels = []
  35. imgs_name = []
  36. num_ng = 0
  37. for i in range(len(data_name)):
  38. res_label = 0
  39. if data_name[i]["file_name"] in boxes.keys():
  40. res_label = 1
  41. num_ng += 1
  42. imgs_labels.append(res_label)
  43. imgs_name.append(img_path+data_name[i]["file_name"])
  44. num_ok = len(data_name)-num_ng
  45. print(len(imgs_labels), num_ok, num_ng)
  46. imgs_name = imgs_name[:10]
  47. print("before infer")
  48. index = 0
  49. num = len(imgs_name)
  50. results_1 = []
  51. step = 64
  52. while index<num:
  53. index += step
  54. if index < num:
  55. results_1_tmp = inference_detector(model_1, imgs_name[index-step:index])
  56. else:
  57. results_1_tmp = inference_detector(model_1, imgs_name[index-step:num])
  58. results_1 += results_1_tmp
  59. print(len(results_1))
  60. print("after infer")
  61. #score_thrs = [0.01, 0.011, 0.012, 0.013, 0.014, 0.015, 0.016, 0.017, 0.018, 0.019, 0.02]
  62. for score_thr in np.arange(0.01, 0.2, 0.01):
  63. imgs_results_1 = []
  64. for result in results_1:
  65. res_predict = 0
  66. print(len(result))
  67. for i in result:
  68. print(i.shape)
  69. for j in range(i.shape[0]):
  70. if i[j, 4]>score_thr:
  71. res_predict = 1
  72. imgs_results_1.append(res_predict)
  73. count_ng = 0
  74. count_ok = 0
  75. for i in range(len(imgs_labels)):
  76. if imgs_labels[i]==0 and imgs_results_1[i]==0:
  77. count_ok += 1
  78. if imgs_labels[i]==1 and imgs_results_1[i]==1:
  79. count_ng += 1
  80. '''if imgs_labels[i]==1 and imgs_results_1[i]==0:
  81. print(imgs_name[i])'''
  82. print("score_thr:", score_thr, " recall(ok):", count_ok/num_ok, " recall(ng):", count_ng/num_ng)

No Description

Contributors (3)