You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

select_threshold.py 3.4 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import sys
  2. import os
  3. import random
  4. import numpy as np
  5. sys.path.append("/home/shanwei-luo/userdata/mmdetection")
  6. from mmdet.apis import (async_inference_detector, inference_detector,
  7. init_detector, show_result_pyplot)
  8. import argparse
  9. #python select_threshold.py --config_file /home/shanwei-luo/teamdata/anomaly_detection_active_learning/model/work_dirs/AD_dsxw_test66_06_10/AD_dsxw_test66_06_10.py --checkpoint_file /home/shanwei-luo/teamdata/anomaly_detection_active_learning/model/work_dirs/AD_dsxw_test66_06_10/latest.pth --images_path /home/shanwei-luo/teamdata/anomaly_detection_active_learning/data0422/smd12_2106_10/test/ --test_batch_size 128
  10. def parse_args():
  11. parser = argparse.ArgumentParser(description='get best threshold')
  12. parser.add_argument('--config_file', help='config')
  13. parser.add_argument('--checkpoint_file', help='checkpoint')
  14. parser.add_argument('--images_path', help='images')
  15. parser.add_argument('--test_batch_size', help='images')
  16. args = parser.parse_args()
  17. return args
  18. args = parse_args()
  19. class_AOI_name = {"bu_pi_pei":"1","fang_xiang_fan":"2","err.txt_c_not_f":"3", "shang_xi_bu_lia":"4"}
  20. config_file_1 = args.config_file
  21. checkpoint_file_1 = args.checkpoint_file
  22. imgs_ok_path = args.images_path+'ok/'
  23. imgs_ng_path = args.images_path+'ng/'
  24. model_1 = init_detector(config_file_1, checkpoint_file_1, device='cuda:0')
  25. imgs_ok = os.listdir(imgs_ok_path)
  26. imgs_ng = os.listdir(imgs_ng_path)
  27. count_label_ok = len(imgs_ok)
  28. count_label_ng = len(imgs_ng)
  29. print(count_label_ok,count_label_ng)
  30. imgs_labels = []
  31. imgs_name = []
  32. for img in imgs_ok:
  33. img_name = img.split("@")
  34. if img_name[2] in class_AOI_name.keys():
  35. count_label_ok -= 1
  36. continue
  37. imgs_labels.append(0)
  38. imgs_name.append(imgs_ok_path+img)
  39. for img in imgs_ng:
  40. img_name = img.split("@")
  41. if img_name[2] in class_AOI_name.keys():
  42. count_label_ng -= 1
  43. continue
  44. imgs_labels.append(1)
  45. imgs_name.append(imgs_ng_path+img)
  46. print(count_label_ok,count_label_ng, len(imgs_labels))
  47. print("before infer")
  48. index = 0
  49. num = len(imgs_name)
  50. results_1 = []
  51. step = int(args.test_batch_size)
  52. while index<num:
  53. index += step
  54. if index < num:
  55. results_1_tmp = inference_detector(model_1, imgs_name[index-step:index])
  56. else:
  57. results_1_tmp = inference_detector(model_1, imgs_name[index-step:num])
  58. results_1 += results_1_tmp
  59. print(len(results_1))
  60. print("after infer")
  61. recall_ok_ng = 0
  62. best_thr = 0.005
  63. for score_thr in np.arange(0.01, 0.3, 0.005):
  64. imgs_results_1 = []
  65. for result in results_1:
  66. res_predict = 0
  67. for i in result:
  68. for j in range(i.shape[0]):
  69. if i[j, 4]>score_thr:
  70. res_predict = 1
  71. imgs_results_1.append(res_predict)
  72. count_ng = 0
  73. count_ok = 0
  74. for i in range(len(imgs_labels)):
  75. if imgs_labels[i]==0 and imgs_results_1[i]==0:
  76. count_ok += 1
  77. if imgs_labels[i]==1 and imgs_results_1[i]==1:
  78. count_ng += 1
  79. recall_ok = count_ok/count_label_ok
  80. recall_ng = count_ng/count_label_ng
  81. print("score_thr:", score_thr, " recall(ok):", recall_ok, " recall(ng):", recall_ng)
  82. if recall_ok_ng < 0.3*recall_ok + 0.7*recall_ng:
  83. recall_ok_ng = 0.3*recall_ok + 0.7*recall_ng
  84. best_thr = score_thr
  85. print("***********************")
  86. print("best threshold:", best_thr)

No Description

Contributors (3)