You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test05.py 4.0 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import sys
  2. import os
  3. import random
  4. sys.path.append("/home/shanwei-luo/userdata/mmdetection")
  5. from mmdet.apis import (async_inference_detector, inference_detector,
  6. init_detector, show_result_pyplot)
  7. class_AOI_name = {"bu_pi_pei":"1","fang_xiang_fan":"2","err.txt_c_not_f":"3", "shang_xi_bu_lia":"4"}
  8. config_file_1 = '/home/shanwei-luo/userdata/mmdetection/work_dirs/AD_dsxw_test61/AD_dsxw_test61.py'
  9. checkpoint_file_1 = '/home/shanwei-luo/userdata/mmdetection/work_dirs/AD_dsxw_test61/epoch_36.pth'
  10. config_file_2 = '/home/shanwei-luo/userdata/mmdetection/work_dirs/AD_dsxw_test62/AD_dsxw_test62.py'
  11. checkpoint_file_2 = '/home/shanwei-luo/userdata/mmdetection/work_dirs/AD_dsxw_test62/epoch_44.pth'
  12. imgs_ok_path = '/home/shanwei-luo/userdata/datasets/dsxw_test_SMD12_2022_0301_0314/ok/'
  13. imgs_ng_path = '/home/shanwei-luo/userdata/datasets/dsxw_test_SMD12_2022_0301_0314/ng/'
  14. model_1 = init_detector(config_file_1, checkpoint_file_1, device='cuda:0')
  15. model_2 = init_detector(config_file_2, checkpoint_file_2, device='cuda:1')
  16. imgs_ok = os.listdir(imgs_ok_path)
  17. imgs_ng = os.listdir(imgs_ng_path)
  18. count_label_ok = len(imgs_ok)
  19. count_label_ng = len(imgs_ng)
  20. print(count_label_ok,count_label_ng)
  21. imgs_labels = []
  22. imgs_name = []
  23. for img in imgs_ok:
  24. img_name = img.split("@")
  25. if img_name[2] in class_AOI_name.keys():
  26. count_label_ok -= 1
  27. continue
  28. imgs_labels.append(0)
  29. imgs_name.append(imgs_ok_path+img)
  30. for img in imgs_ng:
  31. img_name = img.split("@")
  32. if img_name[2] in class_AOI_name.keys():
  33. count_label_ng -= 1
  34. continue
  35. imgs_labels.append(1)
  36. imgs_name.append(imgs_ng_path+img)
  37. print(count_label_ok,count_label_ng, len(imgs_labels))
  38. print("before infer")
  39. index = 0
  40. num = len(imgs_name)
  41. results_1 = []
  42. results_2 = []
  43. step = 256
  44. while index<num:
  45. index += step
  46. if index < num:
  47. results_1_tmp = inference_detector(model_1, imgs_name[index-step:index])
  48. results_2_tmp = inference_detector(model_2, imgs_name[index-step:index])
  49. else:
  50. results_1_tmp = inference_detector(model_1, imgs_name[index-step:num])
  51. results_2_tmp = inference_detector(model_2, imgs_name[index-step:num])
  52. results_1 += results_1_tmp
  53. results_2 += results_2_tmp
  54. print(len(results_1), len(results_2))
  55. print("after infer")
  56. imgs_results_1 = []
  57. for result in results_1:
  58. res_predict = 0
  59. for i in result:
  60. if i.shape[0]>0:
  61. res_predict = 1
  62. imgs_results_1.append(res_predict)
  63. imgs_results_2 = []
  64. for result in results_2:
  65. res_predict = 0
  66. for i in result:
  67. if i.shape[0]>0:
  68. res_predict = 1
  69. imgs_results_2.append(res_predict)
  70. print(len(imgs_results_1), len(imgs_results_2))
  71. imgs_results = []
  72. for r1, r2 in zip(imgs_results_1, imgs_results_2):
  73. if r1 == 1 or r2 == 1:
  74. imgs_results.append(1)
  75. else:
  76. imgs_results.append(0)
  77. print(len(imgs_results_1), len(imgs_results_2), len(imgs_results))
  78. count_ng = 0
  79. count_ok = 0
  80. for i in range(len(imgs_labels)):
  81. if imgs_labels[i]==0 and imgs_results_1[i]==0:
  82. count_ok += 1
  83. if imgs_labels[i]==1 and imgs_results_1[i]==1:
  84. count_ng += 1
  85. if imgs_labels[i]==1 and imgs_results_1[i]==0:
  86. print(imgs_name[i])
  87. print(count_ok/count_label_ok, count_ng/count_label_ng)
  88. count_ng = 0
  89. count_ok = 0
  90. for i in range(len(imgs_labels)):
  91. if imgs_labels[i]==0 and imgs_results_2[i]==0:
  92. count_ok += 1
  93. if imgs_labels[i]==1 and imgs_results_2[i]==1:
  94. count_ng += 1
  95. if imgs_labels[i]==1 and imgs_results_2[i]==0:
  96. print(imgs_name[i])
  97. print(count_ok/count_label_ok, count_ng/count_label_ng)
  98. count_ng = 0
  99. count_ok = 0
  100. for i in range(len(imgs_labels)):
  101. if imgs_labels[i]==0 and imgs_results[i]==0:
  102. count_ok += 1
  103. if imgs_labels[i]==1 and imgs_results[i]==1:
  104. count_ng += 1
  105. if imgs_labels[i]==1 and imgs_results[i]==0:
  106. print(imgs_name[i])
  107. #print(count_ok, count_ng)
  108. #print(count_ng/label_ng)
  109. print(count_ok/count_label_ok, count_ng/count_label_ng)

No Description

Contributors (3)