You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test08.py 2.6 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import sys
  2. import os
  3. import random
  4. import numpy as np
  5. import shutil
  6. sys.path.append("/home/shanwei-luo/userdata/mmdetection")
  7. from mmdet.apis import (async_inference_detector, inference_detector,
  8. init_detector, show_result_pyplot)
  # Detector under evaluation: the config and epoch-38 checkpoint of the
  # AD_dsxw_test70 training run, loaded onto device 'cuda:0'.
  config_file_1 = '/home/shanwei-luo/userdata/mmdetection/work_dirs/AD_dsxw_test70/AD_dsxw_test70.py'
  checkpoint_file_1 = '/home/shanwei-luo/userdata/mmdetection/work_dirs/AD_dsxw_test70/epoch_38.pth'
  model_1 = init_detector(config_file_1, checkpoint_file_1, device='cuda:0')
  # Folder whose files are scored. Keep the trailing '/': file names are
  # appended to this path by plain string concatenation.
  img_path = '/home/shanwei-luo/userdata/datasets/PCBA_dataset_v15_MLOPS/ok/'
  # Presumably destination folders for copying selected images -- confirm.
  dis_path1 = '/home/shanwei-luo/userdata/datasets/PCBA_dataset_v15_MLOPS/ok_score/'
  dis_path2 = '/home/shanwei-luo/userdata/datasets/PCBA_dataset_v15_MLOPS/v1/dsxw_train/images/'
  # Collect the files to score. os.listdir() returns entries in arbitrary
  # order; sorting makes the order of every downstream result (including the
  # rows of the saved score array) reproducible between runs. Sub-directories
  # are skipped because they cannot be decoded as images.
  # `imgs` and `imgs_name` stay index-aligned: imgs_name[k] == img_path + imgs[k].
  imgs = sorted(name for name in os.listdir(img_path)
                if os.path.isfile(img_path + name))
  imgs_name = [img_path + img for img in imgs]
  print(len(imgs_name))
  # Run the detector over imgs_name in batches of `step` paths, concatenating
  # the per-batch outputs in input order (presumably one result per image, so
  # results_1[k] corresponds to imgs_name[k] -- confirm for this mmdet version).
  print("before infer")
  num = len(imgs_name)
  step = 128
  results_1 = []
  for index in range(0, num, step):
      # Slicing clamps at the end, so the last batch may be shorter than step.
      results_1.extend(inference_detector(model_1, imgs_name[index:index + step]))
  print(len(results_1))
  print("after infer")
  # Per-image confidence: the highest column-4 value over all per-class box
  # arrays of that image, or 0 when it has no boxes. (Column 4 is presumably
  # the detection score of an mmdet-style (n, 5) [x1, y1, x2, y2, score]
  # array -- TODO confirm for this config.)
  result_score = []
  for result in results_1:
      # Reduce each non-empty class array in NumPy instead of scanning its
      # rows one by one in Python.
      class_maxima = [bboxes[:, 4].max() for bboxes in result
                      if bboxes.shape[0] > 0]
      # Seed with 0. max() keeps the first of equal items, so this matches a
      # running "replace only if strictly greater" scan that starts at 0.
      result_score.append(max([0] + class_maxima))
  np_result = np.array(result_score)
  print(np_result.shape)
  np.save('/home/shanwei-luo/userdata/datasets/PCBA_dataset_v15_MLOPS/result.npy', np_result)
  # Sweep confidence thresholds 0.50, 0.55, ..., 0.95. For each threshold,
  # count the images with at least one detection whose column-4 value is
  # >= the threshold. The grid is built from integers so every threshold is
  # the exact nearest double of its decimal value; np.arange with a float
  # step can drift by an ulp and prints values with long tails.
  score_thr = 0.07  # only used by the disabled "band" criterion noted below
  for k_v in np.arange(50, 100, 5) / 100:
      count = 0
      for index, result in enumerate(results_1):
          # Compare each class's max (a NumPy scalar) with k_v. This is a
          # scalar-vs-scalar comparison with the same precision as a per-box
          # check, and it short-circuits on the first qualifying class.
          hard = any(bboxes.shape[0] > 0 and bboxes[:, 4].max() >= k_v
                     for bboxes in result)
          # Disabled alternative criterion: mark the image hard when a score
          # lies in [score_thr - score_thr*k_v, score_thr + (1 - score_thr)*k_v].
          if hard:
              count += 1
              # print(img_path+imgs[index], dis_path1+imgs[index])
              # shutil.copy(img_path+imgs[index], dis_path1+imgs[index])
              # shutil.copy(img_path+imgs[index], dis_path2+imgs[index])
      print(k_v, count)
  # Reload the saved per-image max scores and, for each threshold
  # 0.50, 0.55, ..., 0.95, count the images whose score is >= the threshold.
  np_result = np.load('/home/shanwei-luo/userdata/datasets/PCBA_dataset_v15_MLOPS/result.npy')
  result_score = np_result.tolist()
  # Compare in float64 regardless of the dtype the array was saved with.
  # A float32 array compared with a float64 scalar could otherwise be
  # evaluated in float32, and that can flip results exactly at a threshold.
  scores = np_result.astype(np.float64)
  # Integer-built grid: exact decimal thresholds, no float-step drift.
  for k_v in np.arange(50, 100, 5) / 100:
      count = int(np.count_nonzero(scores >= k_v))
      print(k_v, count)

No Description

Contributors (1)