You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test02.py 2.1 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import sys
  2. import os
  3. import random
  4. sys.path.append("/home/shanwei-luo/userdata/mmdetection")
  5. from mmdet.apis import (async_inference_detector, inference_detector,
  6. init_detector, show_result_pyplot)
  7. config_file_1 = '/home/shanwei-luo/userdata/mmdetection/work_dirs/AD_dsxw_test50/AD_dsxw_test50.py'
  8. checkpoint_file_1 = '/home/shanwei-luo/userdata/mmdetection/work_dirs/AD_dsxw_test50/epoch_50.pth'
  9. img_path = '/home/shanwei-luo/userdata/datasets/dsxw_test_0215_0228/images/'
  10. label_path = '/home/shanwei-luo/userdata/datasets/dsxw_test_0215_0228/labels/'
  11. model_1 = init_detector(config_file_1, checkpoint_file_1, device='cuda:1')
  12. imgs = os.listdir(img_path)
  13. labels = os.listdir(label_path)
  14. #img_id = random.randint(0, len(label_path))
  15. label_ng = len(labels)
  16. label_ok = len(imgs)-label_ng
  17. print(label_ok, label_ng)
  18. imgs_labels = []
  19. imgs_name = []
  20. #imgs = imgs[:40]
  21. for img in imgs:
  22. label = img[:-3]+'txt'
  23. res_label = 0
  24. if label in labels:
  25. res_label = 1
  26. imgs_labels.append(res_label)
  27. imgs_name.append(img_path+img)
  28. print(len(imgs_labels))
  29. print("before infer")
  30. index = 0
  31. num = len(imgs_name)
  32. results_1 = []
  33. results_2 = []
  34. step = 32
  35. while index<num:
  36. index += step
  37. if index < num:
  38. results_1_tmp = inference_detector(model_1, imgs_name[index-step:index])
  39. else:
  40. results_1_tmp = inference_detector(model_1, imgs_name[index-step:num])
  41. results_1 += results_1_tmp
  42. print(len(results_1))
  43. print("after infer")
  44. imgs_results_1 = []
  45. for result in results_1:
  46. res_predict = 0
  47. for i in result:
  48. if i.shape[0]>0:
  49. res_predict = 1
  50. imgs_results_1.append(res_predict)
  51. print(len(imgs_results_1))
  52. count_ok_error = 0
  53. f = open("ok_error_SMD12.txt", "w")
  54. for i in range(len(imgs_labels)):
  55. if imgs_labels[i]==0 and imgs_results_1[i]==1:
  56. count_ok_error += 1
  57. print(imgs[i])
  58. f.write(imgs[i]+"\n")
  59. '''if imgs_labels[i]==1 and imgs_labels[i]==imgs_results_1[i]:
  60. count_ng += 1'''
  61. f.close()
  62. print(count_ok_error, count_ok_error/label_ok)

No Description

Contributors (1)