"""Search for the best detection score threshold ('get best threshold').

A trained mmdetection model is evaluated on a test folder containing an
``ok/`` sub-directory (label 0) and an ``ng/`` sub-directory (label 1).
Images whose third ``@``-separated file-name field is a key of
``class_AOI_name`` are excluded from the evaluation.
"""
import sys
import os
import random
import numpy as np
sys.path.append("/home/shanwei-luo/userdata/mmdetection")
from mmdet.apis import (async_inference_detector, inference_detector,
                        init_detector, show_result_pyplot)
import argparse

#python select_threshold.py --config_file /home/shanwei-luo/teamdata/anomaly_detection_active_learning/model/work_dirs/AD_dsxw_test66_06_10/AD_dsxw_test66_06_10.py --checkpoint_file /home/shanwei-luo/teamdata/anomaly_detection_active_learning/model/work_dirs/AD_dsxw_test66_06_10/latest.pth --images_path /home/shanwei-luo/teamdata/anomaly_detection_active_learning/data0422/smd12_2106_10/test/ --test_batch_size 128


def parse_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace with ``config_file``, ``checkpoint_file``,
        ``images_path``, ``test_batch_size`` (int) and ``device``.
    """
    parser = argparse.ArgumentParser(description='get best threshold')
    # The first four options are required: without them the script would
    # only fail later with an unhelpful TypeError on None.
    parser.add_argument('--config_file', required=True,
                        help='mmdetection config file')
    parser.add_argument('--checkpoint_file', required=True,
                        help='model checkpoint file')
    parser.add_argument('--images_path', required=True,
                        help="test image root containing 'ok/' and 'ng/' "
                             "sub-directories")
    parser.add_argument('--test_batch_size', type=int, required=True,
                        help='batch size (step) used by the inference loop')
    parser.add_argument('--device', default='cuda:0',
                        help="device passed to init_detector "
                             "(default: 'cuda:0')")
    args = parser.parse_args()
    return args


def _filter_images(dir_path, file_names, excluded_classes):
    """Return full paths of the files in ``file_names`` that are kept.

    A file is dropped when the third ``@``-separated field of its name is a
    key of ``excluded_classes``.

    Args:
        dir_path: directory path ending with a separator; it is joined with
            each file name.
        file_names: names as returned by ``os.listdir(dir_path)``.
        excluded_classes: mapping whose keys are the field values to drop.

    Raises:
        ValueError: if a file name has fewer than three ``@``-separated
            fields, so the field cannot be read.
    """
    kept = []
    for fname in file_names:
        fields = fname.split("@")
        if len(fields) < 3:
            raise ValueError(
                "unexpected image file name (expected at least three "
                "'@'-separated fields): {}".format(
                    os.path.join(dir_path, fname)))
        if fields[2] in excluded_classes:
            continue
        kept.append(os.path.join(dir_path, fname))
    return kept


args = parse_args()

class_AOI_name = {"bu_pi_pei": "1", "fang_xiang_fan": "2",
                  "err.txt_c_not_f": "3", "shang_xi_bu_lia": "4"}

config_file_1 = args.config_file
checkpoint_file_1 = args.checkpoint_file
# The os.path.join form accepts --images_path with or without a trailing
# slash. The final '' keeps the trailing separator, so the values are the
# same 'root/ok/' and 'root/ng/' strings the original concatenation produced.
imgs_ok_path = os.path.join(args.images_path, 'ok', '')
imgs_ng_path = os.path.join(args.images_path, 'ng', '')

# The listings are sorted because os.listdir order is arbitrary; sorting
# makes the evaluation order, and therefore the printed output, reproducible.
imgs_ok = sorted(os.listdir(imgs_ok_path))
imgs_ng = sorted(os.listdir(imgs_ng_path))
print(len(imgs_ok), len(imgs_ng))

ok_names = _filter_images(imgs_ok_path, imgs_ok, class_AOI_name)
ng_names = _filter_images(imgs_ng_path, imgs_ng, class_AOI_name)
count_label_ok = len(ok_names)
count_label_ng = len(ng_names)
# Labels are 0 for images from ok/ and 1 for images from ng/. They are
# aligned index-for-index with imgs_name.
imgs_name = ok_names + ng_names
imgs_labels = [0] * count_label_ok + [1] * count_label_ng
print(count_label_ok, count_label_ng, len(imgs_labels))

# The recall computation divides by both counts. Fail here, before the
# expensive model load and inference pass, rather than with a
# ZeroDivisionError at the end of the run.
if count_label_ok == 0 or count_label_ng == 0:
    sys.exit("need at least one OK and one NG image after filtering "
             "(got ok={}, ng={})".format(count_label_ok, count_label_ng))

model_1 = init_detector(config_file_1, checkpoint_file_1, device=args.device)

print("before infer")
index = 0
num = len(imgs_name)
# Batched inference over imgs_name, then a sweep over candidate score
# thresholds that reports per-threshold OK/NG recall and the best threshold.
#
# NOTE(review): this line is corrupted in the source. The text between
# `while index` and `score_thr:` is missing. It was probably a `<...>` span
# stripped during extraction, e.g. `while index < num:` ... `> score_thr:`.
# As a result, the loop that fills `results_1`, the initialisation of
# `recall_ok_ng` / `best_thr`, and the head of the threshold sweep (including
# the loop that builds `imgs_results_1`) are not visible. Restore them from
# version control rather than guessing.
#
# Visible tail, per candidate `score_thr`:
#   - `imgs_results_1[i]` is the per-image prediction. It is 1 when the
#     (partially lost) comparison against `score_thr` fires, so 1 = predicted
#     NG. It is compared with `imgs_labels[i]` (0 = ok/ image, 1 = ng/ image).
#   - recall(ok) = correctly kept OK / count_label_ok;
#     recall(ng) = correctly flagged NG / count_label_ng. Either division
#     raises ZeroDivisionError if that class has no images after filtering.
#   - The threshold maximising 0.3*recall_ok + 0.7*recall_ng is kept, so NG
#     recall is weighted more heavily. The strict `<` means that on ties the
#     earliest threshold in sweep order wins.
results_1 = [] step = int(args.test_batch_size) while indexscore_thr: res_predict = 1 imgs_results_1.append(res_predict) count_ng = 0 count_ok = 0 for i in range(len(imgs_labels)): if imgs_labels[i]==0 and imgs_results_1[i]==0: count_ok += 1 if imgs_labels[i]==1 and imgs_results_1[i]==1: count_ng += 1 recall_ok = count_ok/count_label_ok recall_ng = count_ng/count_label_ng print("score_thr:", score_thr, " recall(ok):", recall_ok, " recall(ng):", recall_ng) if recall_ok_ng < 0.3*recall_ok + 0.7*recall_ng: recall_ok_ng = 0.3*recall_ok + 0.7*recall_ng best_thr = score_thr print("***********************") print("best threshold:", best_thr)