You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test01.py 3.5 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import sys
  2. import os
  3. import random
  4. sys.path.append("/home/shanwei-luo/userdata/mmdetection")
  5. from mmdet.apis import (async_inference_detector, inference_detector,
  6. init_detector, show_result_pyplot)
  7. config_file_1 = '/home/shanwei-luo/userdata/mmdetection/work_dirs/AD_dsxw_test63/AD_dsxw_test63.py'
  8. checkpoint_file_1 = '/home/shanwei-luo/userdata/mmdetection/work_dirs/AD_dsxw_test63/epoch_46.pth'
  9. config_file_2 = '/home/shanwei-luo/userdata/mmdetection/work_dirs/AD_dsxw_test62/AD_dsxw_test62.py'
  10. checkpoint_file_2 = '/home/shanwei-luo/userdata/mmdetection/work_dirs/AD_dsxw_test62/epoch_44.pth'
  11. img_path = '/home/shanwei-luo/userdata/datasets/dsxw_test_SMD11_2022_0301_0315/images/'
  12. label_path = '/home/shanwei-luo/userdata/datasets/dsxw_test_SMD11_2022_0301_0315/labels/'
  13. model_1 = init_detector(config_file_1, checkpoint_file_1, device='cuda:0')
  14. model_2 = init_detector(config_file_2, checkpoint_file_2, device='cuda:1')
  15. imgs = os.listdir(img_path)
  16. labels = os.listdir(label_path)
  17. #img_id = random.randint(0, len(label_path))
  18. label_ng = len(labels)
  19. label_ok = len(imgs)-label_ng
  20. print(label_ok, label_ng)
  21. imgs_labels = []
  22. imgs_name = []
  23. for img in imgs:
  24. label = img[:-3]+'txt'
  25. res_label = 0
  26. if label in labels:
  27. res_label = 1
  28. imgs_labels.append(res_label)
  29. imgs_name.append(img_path+img)
  30. print(len(imgs_labels))
  31. print("before infer")
  32. index = 0
  33. num = len(imgs_name)
  34. results_1 = []
  35. results_2 = []
  36. step = 32
  37. while index<num:
  38. index += step
  39. if index < num:
  40. results_1_tmp = inference_detector(model_1, imgs_name[index-step:index])
  41. results_2_tmp = inference_detector(model_2, imgs_name[index-step:index])
  42. else:
  43. results_1_tmp = inference_detector(model_1, imgs_name[index-step:num])
  44. results_2_tmp = inference_detector(model_2, imgs_name[index-step:num])
  45. results_1 += results_1_tmp
  46. results_2 += results_2_tmp
  47. print(len(results_1), len(results_2))
  48. print("after infer")
  49. imgs_results_1 = []
  50. for result in results_1:
  51. res_predict = 0
  52. for i in result:
  53. if i.shape[0]>0:
  54. res_predict = 1
  55. imgs_results_1.append(res_predict)
  56. imgs_results_2 = []
  57. for result in results_2:
  58. res_predict = 0
  59. for i in result:
  60. if i.shape[0]>0:
  61. res_predict = 1
  62. imgs_results_2.append(res_predict)
  63. print(len(imgs_results_1), len(imgs_results_2))
  64. imgs_results = []
  65. for r1, r2 in zip(imgs_results_1, imgs_results_2):
  66. if r1 == 1 or r2 == 1:
  67. imgs_results.append(1)
  68. else:
  69. imgs_results.append(0)
  70. print(len(imgs_results))
  71. count_ng = 0
  72. count_ok = 0
  73. for i in range(len(imgs_labels)):
  74. if imgs_labels[i]==0 and imgs_labels[i]==imgs_results_1[i]:
  75. count_ok += 1
  76. if imgs_labels[i]==1 and imgs_labels[i]==imgs_results_1[i]:
  77. count_ng += 1
  78. print(count_ok/label_ok, count_ng/label_ng)
  79. count_ng = 0
  80. count_ok = 0
  81. for i in range(len(imgs_labels)):
  82. if imgs_labels[i]==0 and imgs_labels[i]==imgs_results_2[i]:
  83. count_ok += 1
  84. if imgs_labels[i]==1 and imgs_labels[i]==imgs_results_2[i]:
  85. count_ng += 1
  86. print(count_ok/label_ok, count_ng/label_ng)
  87. count_ng = 0
  88. count_ok = 0
  89. for i in range(len(imgs_labels)):
  90. if imgs_labels[i]==0 and imgs_labels[i]==imgs_results[i]:
  91. count_ok += 1
  92. if imgs_labels[i]==1 and imgs_labels[i]==imgs_results[i]:
  93. count_ng += 1
  94. #print(count_ok, count_ng)
  95. #print(count_ng/label_ng)
  96. print(count_ok/label_ok, count_ng/label_ng)

No Description

Contributors (3)