You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test06.py 2.7 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import sys
  2. import os
  3. import random
  4. import numpy as np
  5. sys.path.append("/home/shanwei-luo/userdata/mmdetection")
  6. from mmdet.apis import (async_inference_detector, inference_detector,
  7. init_detector, show_result_pyplot)
  8. class_AOI_name = {"bu_pi_pei":"1","fang_xiang_fan":"2","err.txt_c_not_f":"3", "shang_xi_bu_lia":"4"}
  9. config_file_1 = '/home/shanwei-luo/userdata/mmdetection/work_dirs/AD_pd_test01/AD_pd_test01.py'
  10. checkpoint_file_1 = '/home/shanwei-luo/userdata/mmdetection/work_dirs/AD_pd_test01/epoch_18.pth'
  11. imgs_ok_path = '/home/shanwei-luo/userdata/pd_datasets/20220605/OK/'
  12. imgs_ng_path = '/home/shanwei-luo/userdata/pd_datasets/20220605/NG/'
  13. model_1 = init_detector(config_file_1, checkpoint_file_1, device='cuda:0')
  14. imgs_ok = os.listdir(imgs_ok_path)
  15. imgs_ng = os.listdir(imgs_ng_path)
  16. count_label_ok = len(imgs_ok)
  17. count_label_ng = len(imgs_ng)
  18. print(count_label_ok,count_label_ng)
  19. imgs_labels = []
  20. imgs_name = []
  21. for img in imgs_ok:
  22. '''img_name = img.split("@")
  23. if img_name[2] in class_AOI_name.keys():
  24. count_label_ok -= 1
  25. continue'''
  26. imgs_labels.append(0)
  27. imgs_name.append(imgs_ok_path+img)
  28. for img in imgs_ng:
  29. '''img_name = img.split("@")
  30. if img_name[2] in class_AOI_name.keys():
  31. count_label_ng -= 1
  32. continue'''
  33. imgs_labels.append(1)
  34. imgs_name.append(imgs_ng_path+img)
  35. print(count_label_ok,count_label_ng, len(imgs_labels))
  36. print("before infer")
  37. index = 0
  38. num = len(imgs_name)
  39. results_1 = []
  40. step = 64
  41. while index<num:
  42. index += step
  43. if index < num:
  44. results_1_tmp = inference_detector(model_1, imgs_name[index-step:index])
  45. else:
  46. results_1_tmp = inference_detector(model_1, imgs_name[index-step:num])
  47. results_1 += results_1_tmp
  48. print(len(results_1))
  49. print("after infer")
  50. #score_thrs = [0.01, 0.011, 0.012, 0.013, 0.014, 0.015, 0.016, 0.017, 0.018, 0.019, 0.02]
  51. for score_thr in np.arange(0.05, 0.1, 0.001):
  52. imgs_results_1 = []
  53. for result in results_1:
  54. res_predict = 0
  55. for i in result:
  56. for j in range(i.shape[0]):
  57. if i[j, 4]>score_thr:
  58. res_predict = 1
  59. imgs_results_1.append(res_predict)
  60. #print(len(imgs_results_1))
  61. count_ng = 0
  62. count_ok = 0
  63. for i in range(len(imgs_labels)):
  64. if imgs_labels[i]==0 and imgs_results_1[i]==0:
  65. count_ok += 1
  66. if imgs_labels[i]==1 and imgs_results_1[i]==1:
  67. count_ng += 1
  68. if imgs_labels[i]==1 and imgs_results_1[i]==0:
  69. print(imgs_name[i])
  70. print("score_thr:", score_thr, " recall(ok):", count_ok/count_label_ok, " recall(ng):", count_ng/count_label_ng)

No Description

Contributors (3)